学习笔记:Mnist 手写训练集 加入隐藏层后准确率变为0.1的解决办法

  • 提高神经网络准确率的尝试
    • 提高准确率:调小每次训练的批次大小
    • 提高准确率:使用交叉熵
    • 更改优化器及学习率
    • 小结

提高神经网络准确率的尝试

首先是增加隐藏层数,这样有助于提高结果的非线性性,我这里加入了一个100个神经元的中间层,训练目标是将准确率提高到95%以上,不高的要求,事实证明加一层就足够了。

#构建神经网络784-100-10
w_L1=tf.Variable(tf.zeros([784,100]))
b_L1=tf.Variable(tf.zeros([1,100]))
wx_plus_b_L1=tf.matmul(x,w_L1)+b_L1
l1=tf.nn.relu(wx_plus_b_L1)#输出层
w_L2=tf.Variable(tf.zeros([100,10]))
b_L2=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(l1,w_L2)+b_L2)

但这样添加后,不管如何调激活函数还是训练次数,准确率都是0.1135.
原因出在第一层权重设置上,如果改为正太分布的随机值,那么问题就解决了。

改为正太分布的权重:

w_L1=tf.Variable(tf.truncated_normal(([784,100]),stddev=0.1))

训练20次准确度变化:
Iter20Test Accuracy 0.9314
训练30次准确度变化:
Iter30Test Accuracy0.9399
如果继续增加训练次数,准确率应该可以达到95%

提高准确率:调小每次训练的批次大小

如果将batch_size从100调到50,准确率会有明显的上升。
batch_size=50时训练30次的准确率。
Iter30Test Accuracy 0.9567

如果将把batch_size调大呢?
batch_size=200
Iter30Test Accuracy0.9236
可以看到准确率下降了,所以batch_size越小,准确率越高,这可能和优化的次数有关,每次批量处理数据越小,循环完所有数据需要迭代的次数就越多,优化的效果就越好

提高准确率:使用交叉熵

在激活函数为sigmo函数时,使用交叉熵可以获得更好的训练效果,在loss函数中使用交叉熵函数的方法为:

tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction)

这里需要注意一下,在传入logits参数时,prediction值为:

prediction=tf.nn.softmax(tf.matmul(l1,w_L2)+b_L2)

如果以求完softmax的值作为传入值,训练30次后准确率为:

Iter30Test Accuracy0.9642

但如果改为不求softmax,直接传入计算结果值,效果怎么样呢?

prediction=tf.matmul(l1,w_L2)+b_L2

Iter0Test Accuracy0.9336
起步准确率 达到0.93.
训练10个周期后
Iter9Test Accuracy 0.9747

准确率直接提高到0.97,最后准确率保持在0.97左右震荡。

因此,对于交叉熵,logits的传入值应该为不处理softmax的效果最好,这里有人解释是因为在交叉熵中会对logits求一次softmax函数。结合在不使用softmax对L1层数据处理时训练速度更快,后者的解释应该是可靠的。

更改优化器及学习率

首先,随机梯度下降法的收敛速度比较慢,要达到最低点需要迭代很多次,并且对于马鞍面问题其无法逃离马鞍低点平面,这里推荐几个优化器。 Adadelta, Adagrad, 及NAG优化器。
其中 Adadelta, Adagrad收敛速度最快,NAG是优化后的动量优化器,速度也可以。
在TensorFlow中,三者的方法可以很方便的调出。

train_step=tf.train.AdadeltaOptimizer(1e-3).minimize(loss)
train_step=tf.train.AdagradOptimizer(1e-3).minimize(loss)
train_step=tf.train.AdamOptimizer(1e-3).minimize(loss)

在使用了AdamOptimizer,学习率设置为初始0.001,每迭代一次,学习率乘以0.95以后,训练30次,准确率达到98%以上。

lr=tf.Variable(0.001,dtype=tf.float32)

后半段代码

train_step=tf.train.AdamOptimizer(lr).minimize(loss)#初始变量init=tf.global_variables_initializer()#求准确率
with tf.name_scope('accuracy'):with tf.name_scope('correct_prediction'):correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))with tf.name_scope('accuracy'):accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))tf.summary.scalar('accuracy',accuracy)#合并所有指标#训练
with tf.Session() as sess:sess.run(init)for epoch in range(51):sess.run(tf.assign(lr,0.001*0.95**(epoch)))for batch in range(n_batch):batch_xs,batch_ys=mnist.train.next_batch(batch_size)sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})learning_rate=sess.run(lr)test_acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})print('Iter'+str(epoch)+'Test Accuracy'+str(test_acc)+'Learning rate='+str(learning_rate))

小结

将单层神经元数量增大(比如1000),或者增加中间层,可以提高准确率,但这样想要提高到98%以上(注意对权重及偏置值的初始化进行修改,这里用到的是截断的正太分布),还必须对loss函数及优化器进行调整,loss函数修改为交叉熵有利于提高准确率,最后通过使用更好的优化器结合自适应的学习率(在迭代后期不断减小的学习率),可以使损失函数尽量落在最低点而不冲出最优点而造成震荡。

Tensorflow 学习笔记:Mnist 手写训练集调试,准确率变为0.1的解决办法及如何将准确率调高到98%以上相关推荐

  1. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

  2. 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体

    感想 首先我是首先看了一下莫凡pyhton教程中tensorflow python搭建自己的神经网络教程以及查看了官方的教程TensorFlow中文社区-MNIST进阶教程,这里面只是有简单的测试出来 ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. 深度学习笔记:手写一个单隐层的神经网络

    出处:数据科学家养成记 深度学习笔记2:手写一个单隐层的神经网络 笔记1中我们利用 numpy 搭建了神经网络最简单的结构单元:感知机.笔记2将继续学习如何手动搭建神经网络.我们将学习如何利用 num ...

  5. 【安全牛学习笔记】CSRF跨站请求伪造攻击漏洞的原理及解决办法

    CSRF跨站请求伪造攻击漏洞的原理及解决办法 CSRF,夸张请求伪造漏洞 漏洞的原理及修复方法 1.常见的触发场景 2.漏洞原理:浏览器同源策略 3.DEMO 4.漏洞危害 5.如何避免&修复 ...

  6. 《深度学习的数学》学习笔记(手写扫描)

    <深度学习的数学>(人民邮电出版社)本书主要介绍了阶层型神经网络.卷积神经网络.梯度下降法以及误差反向传播法(BP).书中语言风格比较俏皮,深入浅出.就是实战用的是Excel这点比较有个性 ...

  7. NSR学习笔记(手写版)

  8. NAC学习笔记(手写版)

  9. URPF学习笔记(手写版)

最新文章

  1. 计算机原理基础知识pdf,计算机原理第一章.pdf
  2. Log4Net 使用 FileAppender (log4net 1.2.10.0)
  3. CSDN蒋涛大数据表明:DCO - 区块链时代企业级服务的全新机会
  4. 有关项目实施【老男孩】的经验分享
  5. vue3与vue2的详细区别
  6. [Python图像处理] 一.图像处理基础知识及OpenCV入门函数
  7. uva 1203—— Argus
  8. 关于进程资源限制的getrlimit和setrlimit函数(epoll、服务器经常用)
  9. 决策树模型(ID3/C4.5/CART)原理和底层代码解读 学习笔记
  10. android支付平台,android移动支付
  11. python定义变量_Python基础 变量的基本使用
  12. java layer调用native层的android_media_AudioTrack_get_min_buff_size()确定audio track buffer的min size...
  13. 类似switchhost 的简单host切换工具
  14. Android模拟器 使用 Fiddler抓包
  15. 2021-12-11 根据单词首字母查找单词
  16. Windows10如何添加开机启动项
  17. 学校oj显示在线用户数超过了序列号允许。您需要购买或升级您的序列号
  18. Nature证实:学术界刮起离职潮!大批学者涌向工业界,互助文档日均20个学者离职...
  19. 【单片机】用定时器以间隔500ms在8位数码管上依次显示0、1、2、3、...C、D、E、F,重复
  20. 常用独立自建站工具大盘点,哪个性价比更高?

热门文章

  1. 背景图片居中全屏自适应显示
  2. 最简单最节省成本的锂电池充电电路!拆开火火兔,搬起小板凳,听老梁分析...
  3. vue项目在IE浏览器和360兼容模式下页面不显示问题,亲测有效
  4. 修改mobaxterm终端字体大小
  5. pyquery - PyQuery完整的API
  6. browserify 使用
  7. HDU 1009 FatMouse' Trade
  8. JumpServer重置管理员密码
  9. 【C#】WixToolset快速入门教程
  10. 核密度估计和非参数回归