TensorFlow HOWTO 1.4 Softmax 回归
1.4 Softmax 回归
Softmax 回归可以看成逻辑回归在多个类别上的推广。
操作步骤
导入所需的包。
import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn.datasets as ds
import sklearn.model_selection as ms
导入数据,并进行预处理。我们使用鸢尾花数据集所有样本,根据萼片长度和花瓣长度预测样本属于三个品种中的哪一种。
iris = ds.load_iris()x_ = iris.data[:, [0, 2]]
y_ = np.expand_dims(iris.target , 1)x_train, x_test, y_train, y_test = \ms.train_test_split(x_, y_, train_size=0.7, test_size=0.3)
定义超参数。
n_input = 2
n_output = 3
n_epoch = 2000
lr = 0.05
变量 | 含义 |
---|---|
n_input
|
样本特征数 |
n_ouput
|
样本类别数 |
n_epoch
|
迭代数 |
lr
|
学习率 |
搭建模型。
变量 | 含义 |
---|---|
x
|
输入 |
y
|
真实标签 |
y_oh
|
独热的真实标签 |
w
|
权重 |
b
|
偏置 |
z
|
中间变量,x 的线性变换
|
a
|
输出,也就是样本是某个类别的概率 |
x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.int64, [None, 1])
y_oh = tf.one_hot(y, n_output)
y_oh = tf.to_double(tf.reshape(y_oh, [-1, n_output]))
w = tf.Variable(np.random.rand(n_input, n_output))
b = tf.Variable(np.random.rand(1, n_output))
z = x @ w + b
a = tf.nn.softmax(z)
定义损失、优化操作、和准确率度量指标。分类问题有很多指标,这里只展示一种。
我们使用交叉熵损失函数,对于多分类问题,需要改一改,如下。
−mean(sumaxis=1(Y⊗log(A)))-mean(sum_{axis=1}(Y \otimes \log(A)))−mean(sumaxis=1(Y⊗log(A)))
变量 | 含义 |
---|---|
loss
|
损失 |
op
|
优化操作 |
y_hat
|
标签的预测值 |
acc
|
准确率 |
loss = - tf.reduce_mean(tf.reduce_sum(y_oh * tf.log(a), 1))
op = tf.train.AdamOptimizer(lr).minimize(loss)y_hat = tf.argmax(a, 1)
y_hat = tf.expand_dims(y_hat, 1)
acc = tf.reduce_mean(tf.to_double(tf.equal(y_hat, y)))
使用训练集训练模型。
losses = []
accs = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())saver = tf.train.Saver(max_to_keep=1)for e in range(n_epoch):_, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})losses.append(loss_)
使用测试集计算准确率。
acc_ = sess.run(acc, feed_dict={x: x_test, y: y_test})accs.append(acc_)
每一百步打印损失和度量值。
if e % 100 == 0:print(f'epoch: {e}, loss: {loss_}, acc: {acc_}')saver.save(sess,'logit/logit', global_step=e)
得到决策边界:
x_plt = x_[:, 0]y_plt = x_[:, 1]c_plt = y_.ravel()x_min = x_plt.min() - 1x_max = x_plt.max() + 1y_min = y_plt.min() - 1y_max = y_plt.max() + 1x_rng = np.arange(x_min, x_max, 0.05)y_rng = np.arange(y_min, y_max, 0.05)x_rng, y_rng = np.meshgrid(x_rng, y_rng)model_input = np.asarray([x_rng.ravel(), y_rng.ravel()]).Tmodel_output = sess.run(y_hat, feed_dict={x: model_input}).astype(int)c_rng = model_output.reshape(x_rng.shape)
输出:
epoch: 0, loss: 1.4210691245230944, acc: 0.4222222222222222
epoch: 100, loss: 0.34817911438772636, acc: 0.9777777777777777
epoch: 200, loss: 0.24319161311060128, acc: 0.9777777777777777
epoch: 300, loss: 0.19423490522003387, acc: 0.9777777777777777
epoch: 400, loss: 0.16772540127514665, acc: 0.9777777777777777
epoch: 500, loss: 0.15148045580780634, acc: 0.9777777777777777
epoch: 600, loss: 0.14055638836845924, acc: 0.9777777777777777
epoch: 700, loss: 0.1326877769387738, acc: 0.9777777777777777
epoch: 800, loss: 0.12672480658251276, acc: 1.0
epoch: 900, loss: 0.12203422030859229, acc: 1.0
epoch: 1000, loss: 0.11824285244695919, acc: 1.0
epoch: 1100, loss: 0.11511738393720357, acc: 1.0
epoch: 1200, loss: 0.11250383205230477, acc: 1.0
epoch: 1300, loss: 0.11029541725080125, acc: 1.0
epoch: 1400, loss: 0.10841477350763963, acc: 1.0
epoch: 1500, loss: 0.10680373944570205, acc: 1.0
epoch: 1600, loss: 0.10541728211943671, acc: 1.0
epoch: 1700, loss: 0.10421972968246913, acc: 1.0
epoch: 1800, loss: 0.10318232665398802, acc: 1.0
epoch: 1900, loss: 0.10228157312421919, acc: 1.0
绘制整个数据集以及决策边界。
plt.figure()
cmap = mpl.colors.ListedColormap(['r', 'b', 'y'])
plt.scatter(x_plt, y_plt, c=c_plt, cmap=cmap)
plt.contourf(x_rng, y_rng, c_rng, alpha=0.2, linewidth=5, cmap=cmap)
plt.title('Data and Model')
plt.xlabel('Petal Length (cm)')
plt.ylabel('Sepal Length (cm)')
plt.show()
绘制训练集上的损失。
plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('Cross Entropy')
plt.show()
绘制测试集上的准确率。
plt.figure()
plt.plot(accs)
plt.title('Accurary on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('Accurary')
plt.show()
扩展阅读
- WikiPedia: Multinomial logistic regression
TensorFlow HOWTO 1.4 Softmax 回归相关推荐
- TensorFlow HOWTO 1.3 逻辑回归
1.3 逻辑回归 将线性回归的模型改一改,就可以用于二分类.逻辑回归拟合样本属于某个分类,也就是样本为正样本的概率. 操作步骤 导入所需的包. import tensorflow as tf impo ...
- 简单探索MNIST(Softmax回归和两层CNN)-Tensorflow学习
简述 这次是在看<21个项目玩转深度学习>那本书的第一章节后做的笔记. 这段时间,打算把TensorFlow再补补,提升一下技术水平~ 希望我能坚持下来,抽空把这本书刷下来吧~ 导入数据 ...
- TensorFlow精进之路(一):Softmax回归模型训练MNIST
1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...
- TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)
TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...
- 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络
1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...
- 机器学习心得(三)——softmax回归
机器学习心得(三)--softmax回归 在上一篇文章中,主要以二分类为例,讲解了logistic回归模型原理.那么对于多分类问题,我们应该如何处理呢?当然,选择构建许多二分类器进行概率输出自然是一个 ...
- 基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字
原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别. 最后效果 源代码识别mnist数据集的准确 ...
- Softmax 回归 vs. k 个二元分类器
如果你在开发一个音乐分类的应用,需要对k种类型的音乐进行识别,那么是选择使用 softmax 分类器呢,还是使用 logistic 回归算法建立 k 个独立的二元分类器呢? 这一选择取决于你的类别之间 ...
- 【深度学习】基于Pytorch的softmax回归问题辨析和应用(一)
[深度学习]基于Pytorch的softmax回归问题辨析和应用(一) 文章目录 1 概述 2 网络结构 3 softmax运算 4 仿射变换 5 对数似然 6 图像分类数据集 7 数据预处理 8 总 ...
最新文章
- AlphaGo制胜绝招:蒙特卡洛树搜索入门指南
- <马哲>劳动价值论的理论及实践意义
- macos访问linux分区,在linux中访问macos 下的分区。
- 产品经理必备神器推荐
- JsonData响应工具类封装
- php asp.net 代码量少,.NET_asp.net 反射减少代码书写量, 复制代码 代码如下:public b - phpStudy...
- 网络安全——ipsec
- keras中的mini-batch gradient descent (转)
- c++string类的常用方法详解
- [BUAA软工]beta阶段贡献分
- 考研过程中最容易犯的八大错误
- 【JAVASCRIPT】javascript获取屏幕,浏览器,网页高度宽度
- SilverLight学习之基本图形
- 常用电子元件识别图解大全
- 2018百度seo最新算法大全 青岛墨羽SEO统计
- APP调用微信授权登录-JAVA后台实现
- java调用按键精灵安卓_安卓版按键精灵基本功能版
- 面临裁员潮,更快找到新工作的秘诀
- vue校验表格数据_如何通过数据验证限制Google表格中的数据
- 一个故事讲完进程、线程和协程
热门文章
- (22)FPGA软核、固核、硬核介绍
- 数值运算pythonmopn_Python SciPy库——拟合与插值
- 1001.双系统互联的坑
- 【声辐射】——不同坐标系下的格林函数
- Linux-3.2.0.24中内核的Netlink测试使用
- 一个解除TCP连接的TIME_WAIT状态限制的简便方法
- 使用pyinstaller打包python程序时问题记录
- 手机版计算机音乐,计算机音乐手机版
- python网站服务器好麻烦_python写的网站,云服务器经常无法访问
- vue 实现无限轮播_Vue 实现无缝轮播