使用TensorFlow v2.0实现逻辑斯谛回归

此示例使用简单方法来更好地理解训练过程背后的所有机制

MNIST数据集概览

此示例使用MNIST手写数字。该数据集包含60,000个用于训练的样本和10,000个用于测试的样本。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。

在此示例中,每个图像将转换为float32,归一化为[0,1],并展平为784个特征(28 * 28)的1维数组。

from __future__ import absolute_import,division,print_functionimport tensorflow as tfimport numpy as np# MNIST 数据集参数num_classes = 10 # 数字0-9num_features = 784 # 28*28# 训练参数learning_rate = 0.01training_steps = 1000batch_size = 256display_step = 50# 准备MNIST数据from tensorflow.keras.datasets import mnist(x_train, y_train),(x_test,y_test) = mnist.load_data()# 转换为float32x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)# 将图像平铺成784个特征的一维向量(28*28)x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])# 将像素值从[0,255]归一化为[0,1]x_train,x_test = x_train / 255, x_test / 255# 使用tf.data api 对数据随机分布和批处理train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)# 权值矩阵形状[784,10],28 * 28图像特征数和类别数目W = tf.Variable(tf.ones([num_features, num_classes]), name="weight")# 偏置形状[10], 类别数目b = tf.Variable(tf.zeros([num_classes]), name="bias")# 逻辑斯谛回归(Wx+b)def logistic_regression(x): #应用softmax将logits标准化为概率分布 return tf.nn.softmax(tf.matmul(x,W) + b)# 交叉熵损失函数def cross_entropy(y_pred, y_true): # 将标签编码为一个独热编码向量 y_true = tf.one_hot(y_true, depth=num_classes) # 压缩预测值以避免log(0)错误 y_pred = tf.clip_by_value(y_pred, 1e-9, 1.) # 计算交叉熵 return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))# 准确率度量def accuracy(y_pred, y_true): # 预测的类别是预测向量中最高分的索引(即argmax) correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64)) return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 随机梯度下降优化器optimizer = tf.optimizers.SGD(learning_rate)# 优化过程def run_optimization(x, y): #将计算封装在GradientTape中以实现自动微分 with tf.GradientTape() as g: pred = logistic_regression(x) loss = cross_entropy(pred, y)  # 计算梯度 gradients = g.gradient(loss, [W, b])  # 根据gradients更新 W 和 b optimizer.apply_gradients(zip(gradients, [W, b]))# 针对给定训练步骤数开始训练for step, (batch_x,batch_y) in enumerate(train_data.take(training_steps), 1): # 运行优化以更新W和b值 run_optimization(batch_x, batch_y)  if step % display_step == 0: pred = logistic_regression(batch_x) loss = cross_entropy(pred, batch_y) acc = accuracy(pred, batch_y) print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

output:

step: 50, loss: 608.584717, accuracy: 0.824219step: 100, loss: 828.206482, accuracy: 0.765625step: 150, loss: 716.329407, accuracy: 0.746094step: 200, loss: 584.887634, accuracy: 0.820312step: 250, loss: 472.098114, accuracy: 0.871094step: 300, loss: 621.834595, accuracy: 0.832031step: 350, loss: 567.288818, accuracy: 0.714844step: 400, loss: 489.062988, accuracy: 0.847656step: 450, loss: 496.466675, accuracy: 0.843750step: 500, loss: 465.342224, accuracy: 0.875000step: 550, loss: 586.347168, accuracy: 0.855469step: 600, loss: 95.233109, accuracy: 0.906250step: 650, loss: 88.136490, accuracy: 0.910156step: 700, loss: 67.170349, accuracy: 0.937500step: 750, loss: 79.673691, accuracy: 0.921875step: 800, loss: 112.844872, accuracy: 0.914062step: 850, loss: 92.789581, accuracy: 0.894531step: 900, loss: 80.116165, accuracy: 0.921875step: 950, loss: 45.706650, accuracy: 0.925781step: 1000, loss: 72.986969, accuracy: 0.925781# 在验证集上测试模型pred = logistic_regression(x_test)print("Test Accuracy: %f" % accuracy(pred, y_test))

output:

Test Accuracy: 0.901100# 可视化预测import matplotlib.pyplot as plt# 在验证集上中预测5张图片n_images = 5test_images = x_test[:n_images]predictions = logistic_regression(test_images)# 可视化图片和模型预测结果for i in range(n_images): plt.imshow(np.reshape(test_images[i],[28,28]), cmap='gray') plt.show() print("Model prediction: %i" % np.argmax(predictions.numpy()[i]))

output:

Model prediction: 7
Model prediction: 2
Model prediction: 1
Model prediction: 0
Model prediction: 4

判别器loss为0_TensorFlow v2.0实现逻辑斯谛回归相关推荐

  1. 判别器loss为0_scikitlearn—线性判别分析和二次判别分析

    线性判别分析(discriminant_analysis.LinearDiscriminantAnalysis)和二次判别分析(discriminant_analysis.QuadraticDiscr ...

  2. 动易SiteFactory CMS自动采集器 V2.0

    动易SiteFactory CMS采集器说明   注:本程序需要.NET FrameWork 2.0或2.0以上版本的支持(Vista可直接运行本程序,无需安装.NET FrameWork 2.0) ...

  3. Kinect for Windows SDK v2.0 开发笔记 (十三) 高清面部帧(4) 面部模型构建器

     (转载请注明出处) 使用SDK: Kinect for Windows SDK v2.0 public preview1409 同前面,因为SDK未完成,不附上函数/方法/接口的超链接. 这次让 ...

  4. 章节分割器 v2.0 Beta0618 版

    下载:点击此处下载 章节分割器 v2.0 Beta0618 ===================================== 一个把文本小说按照自定义条件切割成章节的软件,没有多么复杂的设置 ...

  5. ENFI下载器 v2.0.3免费版

    点击下载来源:ENFI下载器 v2.0.3免费版 ENFI下载器是一款绿色安全的支持百度网盘资源加速下载工具.界面简洁明了,上手简单,支持各路老司机加速下载网盘资源的软件.资源下载能赚钱+超高速下载, ...

  6. 网游限时器 v2.0 官方

    Welcome to my blog! <script language="javascript" src="http://avss.b15.cnwg.cn/cou ...

  7. python3GUI--翻译器-v2.0(附源码)

    文章目录 一.准备工作 二.预览 1.主界面 2.翻译 3.支持多种语言哦 三.源代码 四.总结 有一个月没发博客了,上次用Tk做了一个翻译器,界面比较粗糙,翻译语言限定为汉语,文本仅支持手动粘贴,遂 ...

  8. MMDetection V2.0发布!速度精度全面提升,现有检测框架最优

    本文授权转自知乎作者陈恺,https://zhuanlan.zhihu.com/p/145084667.未经作者许可,不得二次转载. MMDetection V1.0 版本发布以来,我们收到了很多用户 ...

  9. O-GAN:简单修改,让GAN的判别器变成一个编码器!

    2019-03-08 08:36 作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 本文来给大家分享一下笔者最近的一个工作:通过简单地修改原来的 G ...

最新文章

  1. 查看spark是否有僵尸进程,有的话,先杀掉。可以使用下面命令
  2. Kth Largest Element in an Array
  3. 命令行 笔记本键盘禁用_宏碁发布Enduro系列三防笔记本电脑和平板电脑
  4. 五分钟用vue实现一个五星打分效果
  5. 【caffe-Windows】cifar实例编译之model的使用
  6. CrossPHP框架的常用操作
  7. 初学者看看PHP explode() 函数 第6篇
  8. c#中textbox属性_C#.Net中带有示例的TextBox.Multiline属性
  9. 移除集合效率高还是add高_java集合详解
  10. sql行转列 列数据不定 sql交叉报表实例
  11. lenovo 笔记本ideapad 320c-15改装win7问题
  12. Python的7大就业方向,转行的人适合哪个方向?学了Python能干什么?
  13. Java for循环的几种用法详解(转载)
  14. 计算机扩展屏幕管理软件,小智桌面 - 桌面助手 - 桌面管理美化软件
  15. mp4parser库
  16. phython编写图形界面
  17. 由 UWP 版网易云音乐闪退引发的博文
  18. X265-Android
  19. 火车头传数据到mysql_火车头采集器采集文章使用教程实例
  20. 学习,使用主成分分析 (Principal components analysis,PCA)处理数据必看文章

热门文章

  1. SpringCloud干货(2)---------大时代下的分布式微服务
  2. PostgreSQL数据类型-枚举类型、几何类型、网络地址类型和其他数据类型
  3. 参数--argumengs
  4. 【js与jquery】三级联动菜单的制作
  5. linux 下mysql安装配置管理以及优化
  6. 在网页中给Flash加上超级链接
  7. expect返回值给shell_使用expect实现shell自动交互
  8. 微信摇一摇插件ios_iOS实现微信摇一摇功能
  9. springboot创建单个对象
  10. Nginx防盗链的实现原理和实现步骤