1 前言

上一篇文章 中笔者介绍了如何通过Tensorflow来实现线性回归。在接下来的这篇文章中,笔者将会以Fashion MNIST数据集为例来介绍如何用Tensorflow实现一个Softmax多分类模型。在这篇文章中,我们会开始慢慢接触到Tensoflow中用于实现分类模型的API,例如tf.nn.softmax()softmax_cross_entropy_with_logits_v2等。

2 数据处理

2.1 导入相关包

import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets.fashion_mnist import load_data

正如上一篇文章中说到,tensorflow.keras.datasets中内置了一些比较常用的数据集,并且Tensorflow在各个数据集中均实现了load_data这个方法来对数据集进行载入。所以,上面第三行代码的作用就是用来载入fashion_mnist数据集。

2.2 载入数据

  • 标签转换

    在载入数据前先介绍一个将按序类别编码转化为one-hot编码的方法:

    def dense_to_one_hot(labels_dense, num_classes=10):num_labels = labels_dense.shape[0]index_offset = np.arange(num_labels) * num_classeslabels_one_hot = np.zeros((num_labels, num_classes))labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1return labels_one_hot
    

    该函数的作用就是将普通的[0,1,2,3]类别标签转换为one-hot的编码形式。

  • 载入数据集

    载入任务需要的fashion mnist数据集:

    def load_mnist_data(one_hot=True):(x_train, y_train), (x_test, y_test) = load_data()if one_hot:y_train = dense_to_one_hot(y_train)y_test = dense_to_one_hot(y_test)return x_train / 255., y_train, x_test / 255., y_test
    
  • 批产生器

    构造一个batch迭代器,在训练的过程中每个返回一个batch的数据:

    def gen_batch(x, y, batch_size=64):s_index, e_index, batches = 0, 0 + batch_size, len(y) // batch_sizeif batches * batch_size < len(y):batches += 1for i in range(batches):if e_index > len(y):e_index = len(y)batch_x = x[s_index:e_index]batch_y = y[s_index: e_index]s_index, e_index = e_index, e_index + batch_sizeyield batch_x, batch_y
    

    其中第6、7行的代码表示,当计算得到最后一个batch的结束索引大于数据集长度时,则只取到最后一个样本即可,也就是说,最后一个batch的样本数可能没有batch size个。例如100个样本,batch size 为40,那么每个batch的样本数分别为:40,40和20。

3 框架介绍

3.1 定义正向传播

def forward(x, w, b):f = tf.matmul(x, w) + breturn f

3.2 定义损失

def softmax_cross_entropy(labels, logits):soft_logits = tf.nn.softmax(logits)soft_logits = tf.clip_by_value(soft_logits, 0.000001, 0.999999)cross_entropy = -tf.reduce_sum(labels * tf.log(soft_logits), axis=1)return cross_entropy

tf.nn.softmax()的作用为实现softmax操作的,但有一点值得要说的就是.softmax(dim=-1)中的参数dimdim的默认值是-1,也就是默认将对最后一个维度进行softmax操作。例如在现在这个分类任务中,最后得到logits的维度为[n_samples,n_class],并且我们也的确需要在最后一个维度(每一行)进行softmax操作,因此我们也就没有传入dim的值。可需要我们注意的是,不是所有情况下我们需要进行softmax操作的维度都是最后一个,所有应该要根据实际情况通过dim进行指定

tf.clip_by_value()的作用是对传入值按范围进行裁剪。例如上面第三行代码的作用就是将soft_logits的值限定在[0.000001,0.999999]中,当小于或者大于边界值时,将会强制设定为对于边界值。

tf.log()表示取自然对数,即y=log⁡exy = \log_e xy=loge​x。同时,倒数第二行代码则是用来计算所有样本的交叉熵损失。当然,这个softmax_cross_entropy这个函数的功能tensorflow也已经帮我们实现好了,在下一篇文章中我们会对其进行介绍。

3.3 定义模型

def train(x_train, y_train, x_test, y_test):learning_rate = 0.005epochs = 10batch_size = 256n = 784num_class = 10x = tf.placeholder(dtype=tf.float32, shape=[None, n], name='input_x')y = tf.placeholder(dtype=tf.float32, shape=[None, num_class], name="input_y")w = tf.Variable(tf.truncated_normal(shape=[n, num_class],mean=0, stddev=0.1,dtype=tf.float32))b = tf.Variable(tf.constant(0, dtype=tf.float32, shape=[num_class]))y_pred = forward(x, w, b)cross_entropy = softmax_cross_entropy(y, y_pred)loss = tf.reduce_mean(cross_entropy)train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)correct_prediction = tf.equal(tf.argmax(y_pred, axis=1), tf.argmax(y, axis=1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

由于代码稍微有点长,为了排版美观就分成了定义模型和训练模型来进行介绍,其实两者都在一个函数中。上面代码中,前面大部分都是我们在上一篇文章中说过的,在这里就不再赘述。其中倒数第5、6行代码计算得到交叉熵损失值;tf.train.AdamOptimizer()为基于梯度下降算法改进的另外一种考虑到动量的优化器。

倒数第二行则用来计算预测正确和错误的样本,其中tf.argmax()的作用为取概率值最大所对应的索引(即类标)。例如tf.argmax([0.2,0.5,0.3])的结果为0.5所对应的下标,类似的还有tf.argmin()tf.equal()则用来判断两个输入是否相等的情况,例如tf.equal([0,5,3,6,1],[0,5,3,1,1])的结果为[True,True,True,False,True]

最后一行代码则是用来计算预测的准确率,其中tf.cast()表示进行类型转换,即可以将True转换为1,False转换为0,这样就可以通过计算平均值得到准确率。

3.4 训练模型

with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(epochs):for step, (batch_x, batch_y) in enumerate(gen_batch(x_train, y_train,                    batch_size=batch_size)):batch_x = batch_x.reshape(len(batch_y), -1)feed_dict = {x: batch_x, y: batch_y}l, acc, _ = sess.run([loss, accuracy, train_op], feed_dict=feed_dict)if step % 100 == 0:print("Epochs{}/{}--Batches{}/{}--loss on train:{:.4}--acc: {:.4}".format(epoch, epochs, step, len(x_train) // batch_size, l, acc))if epoch % 2 == 0:total_correct = 0for batch_x, batch_y in gen_batch(x_test, y_test, batch_size):batch_x = batch_x.reshape(len(batch_y), -1)feed_dict = {x: batch_x, y: batch_y}c = sess.run(correct_prediction, feed_dict=feed_dict)total_correct += np.sum(c * 1.)print("Epochs[{}/{}]---acc on test: {:.4}".format(epoch, epochs, total_correct / len(y_test)))

这部分代码仍旧是函数train()中的代码,可以发现里面大部分的知识点我们上一篇文章中已经做过介绍了。这里稍微需要说一下的就是在本示例中,每个epoch喂入的数据并不是全部样本,而是分batch进行输入。其次是在测试集上计算准确率时,我们先是得到了所有batch中预测正确的样本数量,然后除以总数得到准确率的。最后需要注意的是,在sess.run()中如果是计算多个节点的值,则传入的应该是一个list(例如上面代码第七行);如果仅仅只是计算一个节点的值,则传入改节点即可(例如上面代码倒数第四行)。

3.5 运行结果

Epochs[0/10]---Batches[0/234]---loss on train:3.253---acc: 0.07031
Epochs[0/10]---Batches[100/234]---loss on train:0.5581---acc: 0.7891
Epochs[0/10]---Batches[200/234]---loss on train:0.5026---acc: 0.8359
Epochs[0/10]---acc on test: 0.8149
Epochs[1/10]---Batches[0/234]---loss on train:0.4793---acc: 0.8555
Epochs[1/10]---Batches[100/234]---loss on train:0.4411---acc: 0.8281
Epochs[1/10]---Batches[200/234]---loss on train:0.444---acc: 0.8672
Epochs[2/10]---Batches[0/234]---loss on train:0.4289---acc: 0.8789
Epochs[2/10]---Batches[100/234]---loss on train:0.4132---acc: 0.8398
Epochs[2/10]---Batches[200/234]---loss on train:0.421---acc: 0.8672
Epochs[2/10]---acc on test: 0.8335

4 总结

在这篇文章中,笔者首先介绍了如何定义Softmax回归的正向传播以及如何计算得到交叉熵损失函数;然后介绍了如何定义模型中的参数以及准确率的计算等。并同时依次介绍了Tensorflow中所涉及到的各个API,例如tf.nn.softmax()tf.clip_by_value()tf.equal()等。

本次内容就到此结束,感谢您的阅读!若有任何疑问与建议,请添加笔者微信’nulls8’进行交流。青山不改,绿水长流,我们月来客栈见!

引用

[1]示例代码:https://github.com/moon-hotel/Tensorflow1.xTutorials

近期文章

[1]Tensorflow实现线性回归

[2]Tensorflow运行模式

[3]你们要的Tensorflow入坑指南来了

Tensorflow实现Softmax回归相关推荐

  1. TensorFlow实现Softmax

    TensorFlow实现Softmax Regression识别手写数字 本文是按照黄文坚.唐源所著的<TensorFlow实战>一书,进行编写.在TensorFlow实战之余,力求简洁地 ...

  2. 简单探索MNIST(Softmax回归和两层CNN)-Tensorflow学习

    简述 这次是在看<21个项目玩转深度学习>那本书的第一章节后做的笔记. 这段时间,打算把TensorFlow再补补,提升一下技术水平~ 希望我能坚持下来,抽空把这本书刷下来吧~ 导入数据 ...

  3. TensorFlow HOWTO 1.4 Softmax 回归

    1.4 Softmax 回归 Softmax 回归可以看成逻辑回归在多个类别上的推广. 操作步骤 导入所需的包. import tensorflow as tf import numpy as np ...

  4. TensorFlow精进之路(一):Softmax回归模型训练MNIST

    1.MNIST数据集简介: MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0-9十个数字. 如上图所示,每张图片的大小为28×28像素.而标签则由one-hot向 ...

  5. TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%)

    TF:利用是Softmax回归+GD算法实现MNIST手写数字图片识别(10000张图片测试得到的准确率为92%) 目录 设计思路 全部代码 设计思路 全部代码 #TF:利用是Softmax回归+GD ...

  6. TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率

    TF之LoR:基于tensorflow利用逻辑回归算LoR法实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 #TF之LoR:基于tensorflow实现手写数字图片识别准 ...

  7. 【深度学习】基于MindSpore和pytorch的Softmax回归及前馈神经网络

    1 实验内容简介 1.1 实验目的 (1)熟练掌握tensor相关各种操作: (2)掌握广义线性回归模型(logistic模型.sofmax模型).前馈神经网络模型的原理: (3)熟练掌握基于mind ...

  8. 机器学习心得(三)——softmax回归

    机器学习心得(三)--softmax回归 在上一篇文章中,主要以二分类为例,讲解了logistic回归模型原理.那么对于多分类问题,我们应该如何处理呢?当然,选择构建许多二分类器进行概率输出自然是一个 ...

  9. 基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字

    原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别. 最后效果 源代码识别mnist数据集的准确 ...

  10. 基于 TensorFlow 的逻辑回归详解

    Logistic ( 逻辑回归 ) 一.基本概念简介以及理论讲解 1.1.回归 1.2.Logistic 函数的逆函数 –> Logit 函数 1.2.1.伯努利分布 1.2.2 Logit 函 ...

最新文章

  1. mapreduce python实例_MapReduce程序实例(python)
  2. kafka消费者如何读同一生产者消息_Kafka消费者生产者实例
  3. NetScaler SDWAN 详细配置手册
  4. 64位系统上安装apache
  5. Linux之VI命令详解
  6. P4172-[WC2006]水管局长【LCT,最小生成树】
  7. TensorFlow构建二维数据拟合模型(1)
  8. java云题库测试使用说明 0917
  9. 238.除自身以外数组的乘积 (力扣leetcode) 博主可答疑该问题
  10. 中兴F607ZA查看超级管理员密码
  11. 坪山区关于开展2022年度科技创新专项资金申报工作的通知
  12. 手机的内核版本、基带版本等都是什么意思?
  13. vue 使用qrcode生成二维码功能
  14. 基于SpringBoot+Bootstrap【爱码个人博客系统】附源码
  15. 计算机保研面试自我介绍,计算机保研面试英文自我介绍范文
  16. mysql分区表去重复_MySQL分区表管理
  17. oracle vs. SQL 同义词synonym 别名 alias
  18. 用chrome按F12抓包 页面跳转POST一瞬间就闪没了
  19. 商业智能(Business Intelligence,简称:BI)
  20. Flink (四) Flink 的安装和部署- Flink on Yarn 模式 / 集群HA / 并行度和Slot

热门文章

  1. 内网计算机可以使用键盘,如何在同一个局域网里一套键盘鼠标操作多台电脑?...
  2. KeyPass密码管理软件使用说明
  3. 理解Iass Pass SasS三种云服务区别
  4. 设计之星 ai_二十万人的AI成长之路 ,百度之星用十五年去点亮
  5. 暑假学习打卡【4】——北理工乐学第四周作业
  6. 服务器 字体文件夹,服务器安装字体
  7. VPX信号处理板学习资料第274篇:基于XC7V690T的3U VPX信号处理板
  8. 服务器内存系统,服务器内存系统容量
  9. 基于MATLAB的战术手势识别功能的设计与实现
  10. 网易云音乐java爬虫_用Java实现网易云音乐爬虫