1.4 Softmax 回归

Softmax 回归可以看成逻辑回归在多个类别上的推广。

操作步骤

导入所需的包。

import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms

导入数据,并进行预处理。我们使用鸢尾花数据集所有样本,根据萼片长度和花瓣长度预测样本属于三个品种中的哪一种。

iris = ds.load_iris()x_ = iris.data[:, [0, 2]]
y_ = np.expand_dims(iris.target , 1)x_train, x_test, y_train, y_test = \ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)

定义超参数。

n_input = 2
n_output = 3
n_epoch = 2000
lr = 0.05
变量 含义
n_input 样本特征数
n_ouput 样本类别数
n_epoch 迭代数
lr 学习率

搭建模型。

变量 含义
x 输入
y 真实标签
y_oh 独热的真实标签
w 权重
b 偏置
z 中间变量,x的线性变换
a 输出,也就是样本是某个类别的概率
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.int64, [None, 1])
y_oh = tf.one_hot(y, n_output)
y_oh = tf.to_double(tf.reshape(y_oh, [-1, n_output]))
w = tf.Variable(np.random.rand(n_input, n_output))
b = tf.Variable(np.random.rand(1, n_output))
z = x @ w + b
a = tf.nn.softmax(z)

定义损失、优化操作、和准确率度量指标。分类问题有很多指标,这里只展示一种。

我们使用交叉熵损失函数,对于多分类问题,需要改一改,如下。

−mean(sumaxis=1(Y⊗log⁡(A)))-mean(sum_{axis=1}(Y \otimes \log(A)))−mean(sumaxis=1​(Y⊗log(A)))

变量 含义
loss 损失
op 优化操作
y_hat 标签的预测值
acc 准确率
loss = - tf.reduce_mean(tf.reduce_sum(y_oh * tf.log(a), 1))
op = tf.train.AdamOptimizer(lr).minimize(loss)y_hat = tf.argmax(a, 1)
y_hat = tf.expand_dims(y_hat, 1)
acc = tf.reduce_mean(tf.to_double(tf.equal(y_hat, y)))

使用训练集训练模型。

losses = []
accs = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver = tf.train.Saver(max_to_keep=1)for e in range(n_epoch):_, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})losses.append(loss_)

使用测试集计算准确率。

        acc_ = sess.run(acc, feed_dict={x: x_test, y: y_test})accs.append(acc_)

每一百步打印损失和度量值。

        if e % 100 == 0:print(f'epoch: {e}, loss: {loss_}, acc: {acc_}')saver.save(sess,'logit/logit', global_step=e)

得到决策边界:

    x_plt = x_[:, 0]y_plt = x_[:, 1]c_plt = y_.ravel()x_min = x_plt.min() - 1x_max = x_plt.max() + 1y_min = y_plt.min() - 1y_max = y_plt.max() + 1x_rng = np.arange(x_min, x_max, 0.05)y_rng = np.arange(y_min, y_max, 0.05)x_rng, y_rng = np.meshgrid(x_rng, y_rng)model_input = np.asarray([x_rng.ravel(), y_rng.ravel()]).Tmodel_output = sess.run(y_hat, feed_dict={x: model_input}).astype(int)c_rng = model_output.reshape(x_rng.shape)

输出:

epoch: 0, loss: 1.4210691245230944, acc: 0.4222222222222222
epoch: 100, loss: 0.34817911438772636, acc: 0.9777777777777777
epoch: 200, loss: 0.24319161311060128, acc: 0.9777777777777777
epoch: 300, loss: 0.19423490522003387, acc: 0.9777777777777777
epoch: 400, loss: 0.16772540127514665, acc: 0.9777777777777777
epoch: 500, loss: 0.15148045580780634, acc: 0.9777777777777777
epoch: 600, loss: 0.14055638836845924, acc: 0.9777777777777777
epoch: 700, loss: 0.1326877769387738, acc: 0.9777777777777777
epoch: 800, loss: 0.12672480658251276, acc: 1.0
epoch: 900, loss: 0.12203422030859229, acc: 1.0
epoch: 1000, loss: 0.11824285244695919, acc: 1.0
epoch: 1100, loss: 0.11511738393720357, acc: 1.0
epoch: 1200, loss: 0.11250383205230477, acc: 1.0
epoch: 1300, loss: 0.11029541725080125, acc: 1.0
epoch: 1400, loss: 0.10841477350763963, acc: 1.0
epoch: 1500, loss: 0.10680373944570205, acc: 1.0
epoch: 1600, loss: 0.10541728211943671, acc: 1.0
epoch: 1700, loss: 0.10421972968246913, acc: 1.0
epoch: 1800, loss: 0.10318232665398802, acc: 1.0
epoch: 1900, loss: 0.10228157312421919, acc: 1.0

绘制整个数据集以及决策边界。

plt.figure()
cmap = mpl.colors.ListedColormap(['r', 'b', 'y'])
plt.scatter(x_plt, y_plt, c=c_plt, cmap=cmap)
plt.contourf(x_rng, y_rng, c_rng, alpha=0.2, linewidth=5, cmap=cmap)
plt.title('Data and Model')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Sepal Length (cm)')
plt.show()

绘制训练集上的损失。

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('Cross Entropy')
plt.show()

绘制测试集上的准确率。

plt.figure()
plt.plot(accs)
plt.title('Accurary on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('Accurary')
plt.show()

扩展阅读

  • WikiPedia: Multinomial logistic regression

TensorFlow HOWTO 1.4 Softmax 回归相关推荐

  1. TensorFlow HOWTO 1.3 逻辑回归

    1.3 逻辑回归 将线性回归的模型改一改,就可以用于二分类.逻辑回归拟合样本属于某个分类,也就是样本为正样本的概率. 操作步骤 导入所需的包. import tensorflow as tf impo ...

  2. 简单探索MNIST(Softmax回归和两层CNN)-Tensorflow学习

    简述 这次是在看<21个项目玩转深度学习>那本书的第一章节后做的笔记. 这段时间,打算把TensorFlow再补补,提升一下技术水平~ 希望我能坚持下来,抽空把这本书刷下来吧~ 导入数据 ...

  3. TensorFlow精进之路(一):Softmax回归模型训练MNIST

    1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...

  4. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  5. 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络

    1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...

  6. 机器学习心得(三)——softmax回归

    机器学习心得(三)--softmax回归 在上一篇文章中,主要以二分类为例,讲解了logistic回归模型原理.那么对于多分类问题,我们应该如何处理呢?当然,选择构建许多二分类器进行概率输出自然是一个 ...

  7. 基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字

    原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别. 最后效果 源代码识别mnist数据集的准确 ...

  8. Softmax 回归 vs. k 个二元分类器

    如果你在开发一个音乐分类的应用,需要对k种类型的音乐进行识别,那么是选择使用 softmax 分类器呢,还是使用 logistic 回归算法建立 k 个独立的二元分类器呢? 这一选择取决于你的类别之间 ...

  9. 【深度学习】基于Pytorch的softmax回归问题辨析和应用(一)

    [深度学习]基于Pytorch的softmax回归问题辨析和应用(一) 文章目录 1 概述 2 网络结构 3 softmax运算 4 仿射变换 5 对数似然 6 图像分类数据集 7 数据预处理 8 总 ...

最新文章

  1. AlphaGo制胜绝招:蒙特卡洛树搜索入门指南
  2. <马哲>劳动价值论的理论及实践意义
  3. macos访问linux分区,在linux中访问macos 下的分区。
  4. 产品经理必备神器推荐
  5. JsonData响应工具类封装
  6. php asp.net 代码量少,.NET_asp.net 反射减少代码书写量, 复制代码 代码如下:public b - phpStudy...
  7. 网络安全——ipsec
  8. keras中的mini-batch gradient descent (转)
  9. c++string类的常用方法详解
  10. [BUAA软工]beta阶段贡献分
  11. 考研过程中最容易犯的八大错误
  12. 【JAVASCRIPT】javascript获取屏幕,浏览器,网页高度宽度
  13. SilverLight学习之基本图形
  14. 常用电子元件识别图解大全
  15. 2018百度seo最新算法大全 青岛墨羽SEO统计
  16. APP调用微信授权登录-JAVA后台实现
  17. java调用按键精灵安卓_安卓版按键精灵基本功能版
  18. 面临裁员潮,更快找到新工作的秘诀
  19. vue校验表格数据_如何通过数据验证限制Google表格中的数据
  20. 一个故事讲完进程、线程和协程

热门文章

  1. (22)FPGA软核、固核、硬核介绍
  2. 数值运算pythonmopn_Python SciPy库——拟合与插值
  3. 1001.双系统互联的坑
  4. 【声辐射】——不同坐标系下的格林函数
  5. Linux-3.2.0.24中内核的Netlink测试使用
  6. 一个解除TCP连接的TIME_WAIT状态限制的简便方法
  7. 使用pyinstaller打包python程序时问题记录
  8. 手机版计算机音乐,计算机音乐手机版
  9. python网站服务器好麻烦_python写的网站,云服务器经常无法访问
  10. vue 实现无限轮播_Vue 实现无缝轮播