1.3 逻辑回归

将线性回归的模型改一改,就可以用于二分类。逻辑回归拟合样本属于某个分类,也就是样本为正样本的概率。

操作步骤

导入所需的包。

import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms

导入数据,并进行预处理。我们使用鸢尾花数据集所有样本,根据萼片长度和花瓣长度预测样本是不是山鸢尾(第一种)。

iris = ds.load_iris()x_ = iris.data[:, [0, 2]]
y_ = (iris.target == 0).astype(int)
y_ = np.expand_dims(y_ , 1)x_train, x_test, y_train, y_test = \ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)

定义超参数。

变量 含义
n_input 样本特征数
n_epoch 迭代数
lr 学习率
threshold 如果输出超过这个概率,将样本判定为正样本
n_input = 2
n_epoch = 2000
lr = 0.05
threshold = 0.5

搭建模型。

变量 含义
x 输入
y 真实标签
w 权重
b 偏置
z 中间变量,x的线性变换
a 输出,也就是样本是正样本的概率
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.float64, [None, 1])
w = tf.Variable(np.random.rand(n_input, 1))
b = tf.Variable(np.random.rand(1, 1))
z = x @ w + b
a = tf.sigmoid(z)

定义损失、优化操作、和准确率度量指标。分类问题有很多指标,这里只展示一种。

我们使用交叉熵损失函数,如下。

−mean(Y⊗log⁡(A)+(1−Y)⊗log⁡(1−A))-mean(Y \otimes \log(A) + (1-Y) \otimes \log(1-A))−mean(Y⊗log(A)+(1−Y)⊗log(1−A))

它的意思是,对于正样本,y 为 1,损失变为-log(a),输出会尽可能接近一。对于负样本,y为 0,损失变为-log(1 - a),输出会尽可能接近零。总之,它使输出尽可能接近真实标签。

变量 含义
loss 损失
op 优化操作
y_hat 标签的预测值
acc 准确率
loss = - tf.reduce_mean(y * tf.log(a) + (1 - y) * tf.log(1 - a))
op = tf.train.AdamOptimizer(lr).minimize(loss)y_hat = tf.to_double(a > threshold)
acc = tf.reduce_mean(tf.to_double(tf.equal(y_hat, y)))

使用训练集训练模型。

losses = []
accs = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver = tf.train.Saver(max_to_keep=1)for e in range(n_epoch):_, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})losses.append(loss_)

使用测试集计算准确率。

        acc_ = sess.run(acc, feed_dict={x: x_test, y: y_test})accs.append(acc_)

每一百步打印损失和度量值。

        if e % 100 == 0:print(f'epoch: {e}, loss: {loss_}, acc: {acc_}')saver.save(sess,'logit/logit', global_step=e)

得到决策边界:

    x_plt = x_[:, 0]y_plt = x_[:, 1]c_plt = y_.ravel()x_min = x_plt.min() - 1x_max = x_plt.max() + 1y_min = y_plt.min() - 1y_max = y_plt.max() + 1x_rng = np.arange(x_min, x_max, 0.05)y_rng = np.arange(y_min, y_max, 0.05)x_rng, y_rng = np.meshgrid(x_rng, y_rng)model_input = np.asarray([x_rng.ravel(), y_rng.ravel()]).Tmodel_output = sess.run(y_hat, feed_dict={x: model_input}).astype(int)c_rng = model_output.reshape(x_rng.shape)

输出:

epoch: 0, loss: 3.935746371309244, acc: 0.3333333333333333
epoch: 100, loss: 0.1969325408656252, acc: 1.0
epoch: 200, loss: 0.08548362243852041, acc: 1.0
epoch: 300, loss: 0.050833687966014396, acc: 1.0
epoch: 400, loss: 0.034929315249291375, acc: 1.0
epoch: 500, loss: 0.026013692651528184, acc: 1.0
epoch: 600, loss: 0.02038864243607467, acc: 1.0
epoch: 700, loss: 0.016552042129938136, acc: 1.0
epoch: 800, loss: 0.013786692432697542, acc: 1.0
epoch: 900, loss: 0.011709709551073783, acc: 1.0
epoch: 1000, loss: 0.010099234422592073, acc: 1.0
epoch: 1100, loss: 0.008818382202721829, acc: 1.0
epoch: 1200, loss: 0.007778392815694136, acc: 1.0
epoch: 1300, loss: 0.0069193419951217704, acc: 1.0
epoch: 1400, loss: 0.0061993983430654875, acc: 1.0
epoch: 1500, loss: 0.00558852696047961, acc: 1.0
epoch: 1600, loss: 0.005064638072189167, acc: 1.0
epoch: 1700, loss: 0.00461114435393481, acc: 1.0
epoch: 1800, loss: 0.004215362417896155, acc: 1.0
epoch: 1900, loss: 0.003867437954560204, acc: 1.0

绘制整个数据集以及决策边界。

plt.figure()
cmap = mpl.colors.ListedColormap(['r', 'b'])
plt.scatter(x_plt, y_plt, c=c_plt, cmap=cmap)
plt.contourf(x_rng, y_rng, c_rng, alpha=0.2, linewidth=5, cmap=cmap)
plt.title('Data and Model')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Sepal Length (cm)')
plt.show()

绘制训练集上的损失。

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('Cross Entropy')
plt.show()

绘制测试集上的准确率。

plt.figure()
plt.plot(accs)
plt.title('Accurary on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('Accurary')
plt.show()

扩展阅读

  • 斯坦福 CS229 笔记:六、逻辑回归

TensorFlow HOWTO 1.3 逻辑回归相关推荐

  1. TensorFlow HOWTO 1.4 Softmax 回归

    1.4 Softmax 回归 Softmax 回归可以看成逻辑回归在多个类别上的推广. 操作步骤 导入所需的包. import tensorflow as tf import numpy as np ...

  2. 使用TensorFlow编程实现一元逻辑回归

    内容回顾 逻辑回归是在线性模型的基础上,再增加一个Sigmoid函数来实现的. 输入样本特征,经过线性组合之后,得到的是一个连续值,经过Sigmoid函数,把它转化为一个0-1之间的概率,再通过设置一 ...

  3. TensorFlow基础7-机器学习基础知识(逻辑回归,鸢尾花实现多分类)

    记录TensorFlow听课笔记 文章目录 记录TensorFlow听课笔记 一,线性回归 二,广义线性回归 三,一元/多元逻辑回归 四,实现一元逻辑回归 五,多分类问题 六,TensorFlow实现 ...

  4. Tensorflow【实战Google深度学习框架】—Logistic regression逻辑回归模型实例讲解

    文章目录 1.前言 2.程序详细讲解 环境设定 数据读取 准备好placeholder,开好容器来装数据 准备好参数/权重 拿到每个类别的score 计算多分类softmax的loss functio ...

  5. TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率

    TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...

  6. python 多分类逻辑回归_机器学习实践:多分类逻辑回归(softmax回归)的sklearn实现和tensorflow实现...

    本文所有代码及数据可下载. Scikit Learn 篇:Light 版 scikit learn内置了逻辑回归,对于小规模的应用较为简单,一般使用如下代码即可 from sklearn.linear ...

  7. tensorflow综合示例4:逻辑回归:使用Estimator

    文章目录 1.加载csv格式的数据集并生成Dataset 1.1 pandas读取csv数据生成Dataframe 1.2 将Dataframe生成Dataset 2.将数据封装成Feature co ...

  8. 逻辑回归实现多分类任务(python+TensorFlow+mnist)

    逻辑回归实现多分类任务(python+TensorFlow+mnist) 逻辑回归是统计学中的一种经典方法,虽然叫回归,但在机器学习领域,逻辑回归通常情况下当成一个分类任务,softmax就是由其演变 ...

  9. Tensorflow逻辑回归处理MNIST数据集

    #1:导入所需的软件 import tensorflow as tf ''' 获取mnist数据放在当前文件夹下,利用input_data函数解析该数据集 train_img和train--label ...

最新文章

  1. win 10无法启动print spooler服务,提示1068依赖服务或组无法启动
  2. zabbix 安装_Zabbix的WEB安装与配置
  3. [置顶]       ibatis做分页
  4. bind、delegate、on的区别
  5. “我哥毕业1年,做Python挣了50W!”网友:吹得太少...
  6. 用Python将音频内容转换为文本格式
  7. 使用TensorFlow.js进行人脸触摸检测第2部分:使用BodyPix
  8. Ansible 学习总结(1)—— Ansible 入门详解
  9. 有标号的DAG计数 II
  10. 这个简单的常见面试题,怎么答才会加分?
  11. 小米5x 运行linux,小米5X root+xposed使用方法
  12. 如何实现施耐德Twido系列PLC远程上下载
  13. uchome持久XSS(2.0版本测试通过)
  14. 恶意软件清理助手 v2.50 Build 005
  15. c语言十佳歌手程序,十佳歌手决赛的细则流程
  16. ajax执行先后顺序
  17. 苹果吃鸡蓝牙耳机推荐哪个?性价比高的游戏蓝牙耳机推荐
  18. LINODE优惠码与服务器搭建
  19. java lcm_Orac and LCM
  20. Principal Components Analysis

热门文章

  1. FPGA可综合语句建立原则
  2. python函数type的用意_Python内置函数Type()函数一个有趣的用法
  3. 嵌入式 U 盘自动挂载
  4. oracle 行数大于一时,oracle – PL / SQL ORA-01422:精确的提取返回超过请求的行数
  5. nginx基础概念(100%)之pipe
  6. (1)散列表(哈希表)的定义
  7. 【C语言】C语言学习整理-putchar,printf,getchar,scanf定义及区别
  8. QT5开发及实例学习之十Qt5主窗口构成
  9. server sql 分组 去重 字符串拼接_SQL必知必会
  10. c语言调用子程序,哪位师傅知道51单片机怎样编写子程序?C语言的。在主程序里调...