TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

目录

输出结果

设计代码


输出结果

设计代码

import tensorflow as tf
from sklearn.datasets import load_digits
#from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer# load data
digits = load_digits()  X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)def add_layer(inputs, in_size, out_size, layer_name, activation_function=None, ):# add one more layer and return the output of this layerWeights = tf.Variable(tf.random_normal([in_size, out_size]))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, )Wx_plus_b = tf.matmul(inputs, Weights) + biases# here to dropoutWx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)  if activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b, )tf.summary.histogram(layer_name + '/outputs', outputs) return outputs# define placeholder for inputs to network
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32, [None, 64])
ys = tf.placeholder(tf.float32, [None, 10])# add output layer
l1 = add_layer(xs, 64, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax) # the loss between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))
tf.summary.scalar ('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.Session()
merged =  tf.summary.merge_all()
# summary writer goes in here
train_writer = tf.summary.FileWriter("logs4/train", sess.graph)
test_writer = tf.summary.FileWriter("logs4/test", sess.graph)    sess.run(tf.global_variables_initializer()) for i in range(500):  # here to determine the keeping probabilitysess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})  if i % 50 == 0:# record losstrain_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})train_writer.add_summary(train_result, i)  test_writer.add_summary(test_result, i)

相关文章
TF:利用sklearn自带数据集使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线相关推荐

  1. DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率

    DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率 目录 输出结果 核心代码 输出结果 核心代码 #DL之NN:基于sklearn ...

  2. DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练、预测(95%)

    DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练.预测(95%) 目录 数据集展示 输出结果 设计代码 数据集展示 先查看sklearn自带di ...

  3. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  4. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  5. TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99%

    TF:基于CNN(2+1)实现MNIST手写数字图片识别准确率提高到99% 导读 与Softmax回归模型相比,使用两层卷积的神经网络模型借助了卷积的威力,准确率高非常大的提升. 目录 输出结果 代码 ...

  6. TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程

    TF之DNN:利用DNN[784→500→10]对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程 目录 输出结果 案例理解DNN过程思路 代码设计 输出结果 案 ...

  7. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  8. TF之LiR:基于tensorflow实现手写数字图片识别准确率

    TF之LiR:基于tensorflow实现手写数字图片识别准确率 目录 输出结果 代码设计 输出结果 Extracting MNIST_data\train-images-idx3-ubyte.gz ...

  9. ML之K-means:基于(完整的)手写数字图片识别数据集利用K-means算法实现图片聚类

    ML之K-means:基于(完整的)手写数字图片识别数据集利用K-means算法实现图片聚类 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 metrics.adjusted_ran ...

最新文章

  1. hihocoder 1490 Tree Restoration
  2. 运用node实现简单爬虫
  3. NHibernate部分错误
  4. mysql和mariadb可以同时使用吗_10分钟实现MariaDB与MySQL在一台服务器同时运行
  5. 有赞统一接入层架构演进
  6. Tcpdump(linux)下载、安装、使用说明
  7. 无线路由器和计算机怎么连接网络连接,华为无线路由器怎么连接宽带上网
  8. HI3559A和AI深度学习框架caffe
  9. 博图买什么样配置的笔记本_3dsmax需要什么样的笔记本配置?
  10. 自动输入命令执行_Ubuntu命令行操作-命令简介
  11. 应用虑镜特效时遇到浏览器权限问题
  12. Idea查看文件结构,类似Eclipse中Ctrl+O
  13. [OpenS-CAD]屏幕坐标转换分析
  14. pb9 调用系统语音_语音通知解决方案,VIKI语音通知软件介绍
  15. 博弈论基础知识--非合作博弈,零和博弈,负和博弈,主从博弈,Nash均衡
  16. 如何判断企业微信是否在线?
  17. 标签上的title属性和alt属性有什么区别
  18. 电脑系统重装win7的教程,win7系统一键安装
  19. 【并发编程】CPU cache结构和缓存一致性(MESI协议)
  20. 昆明理工大学计算机考研分数线,昆明理工大学2015考研分数线已公布

热门文章

  1. OkHttp实现文件上传进度
  2. 《数据中心设计与运营实战》——2.6 监控基础设施
  3. linux centos7清除系统日志、历史记录、登录信息
  4. 把软件放到图片里(超强)
  5. ASP解析JSON例子
  6. nacos集群之日志狂刷fail to connect server,after trying 567 times,last try server is...
  7. 曾经废寝忘食学到的技术,现在都没用了......
  8. 说实话,你工作5年,不知道什么是Java agent技术,让我很吃惊...
  9. 干货,springboot自定义注解实现分布式锁详解
  10. 框架:SpringMVC的工作原理