接上篇caffe手写数字识别-训练模型

import matplotlib.pyplot as plt
import numpy as np
import caffe
# 对之前的solver.prototxt文件进行了修改,
# 没十次训练,测试一轮,显示一次训练测试loss,主要是为了画学习曲线时,x坐标参数相同
caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver('mnist/solver.prototxt')
max_iter = 6000      # 训练batch_size的数量
display =  10         # 每训练10个batch_size的图片显示一次loss即1000个样本
test_iter = 100       # 每次测试100个样本,100次测试完10000个测试样例
test_interval = 10    # 测试间隔,每训练1000个样本,进行一次测试
# np.ceil([-0.1,2,1.1])  >>> [-0,2,2]  向上取整
# Try to install numpy 1.11.0 sudo pip install -U numpy==1.11.0.
# caused by unsupported float index in 1.12.0
# even if the case likex[1.0: 3.0] should be considered as valid.
'''
train_loss = np.zeros(np.ceil(max_iter*1.0/display))
test_loss  = np.zeros(np.ceil(max_iter*1.0/test_interval))
test_acc   = np.zeros(np.ceil(max_iter*1.0/test_interval))
'''
'\ntrain_loss = np.zeros(np.ceil(max_iter*1.0/display))\ntest_loss  = np.zeros(np.ceil(max_iter*1.0/test_interval))\ntest_acc   = np.zeros(np.ceil(max_iter*1.0/test_interval))\n'
train_loss = []
test_loss = []
test_acc = []
trainLoss = 0
testLoss = 0
testAcc = 0
# 完成一个batch_size样本的训练
solver.step(1)
for i in xrange(max_iter):solver.step(1)trainLoss += solver.net.blobs['SoftmaxWithLoss1'].data# 每完成20个batch_size的训练,计算一次平均lossif i % display == 0:train_loss.append(trainLoss/display)trainLoss = 0# 每完成1epoch的训练进行,一轮测试,计算测试平均loss和accuif i % test_interval == 0:# 一共10000个测试样本,每次测试100个,测试100次for ti in xrange(test_iter):solver.test_nets[0].forward()testLoss += solver.test_nets[0].blobs['SoftmaxWithLoss1'].datatestAcc += solver.test_nets[0].blobs['Accuracy1'].data# 计算平均值test_loss.append(testLoss/test_iter)test_acc.append(testAcc/test_iter)testLoss = 0testAcc = 0
len(train_loss),len(test_loss),len(test_acc),
(600, 600, 600)
plt.figure(figsize=(15,6))
plt.plot(train_loss,'r',label='Train Loss')
plt.plot(test_loss, 'g',label='Test Loss')
plt.plot(test_acc, 'b',label='Test Accu')
plt.grid()
plt.title('Learning Cruve')
plt.legend()
plt.show()

caffe手写数字分类-学习曲线相关推荐

  1. 如何为MNIST手写数字分类开发CNN

    导言 MNIST手写数字分类问题是计算机视觉和深度学习中使用的标准数据集. 虽然数据集得到了有效的解决,但它可以作为学习和实践如何开发,评估和使用卷积深度学习神经网络从头开始进行图像分类的基础.这包括 ...

  2. 独家 | 如何从头开始为MNIST手写数字分类建立卷积神经网络(附代码)

    翻译:张睿毅 校对:吴金笛 本文约9300字,建议阅读20分钟. 本文章逐步介绍了卷积神经网络的建模过程,最终实现了MNIST手写数字分类. MNIST手写数字分类问题是计算机视觉和深度学习中使用的标 ...

  3. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  4. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  5. 练习利用LSTM实现手写数字分类任务

    练习利用LSTM实现手写数字分类任务 MNIST数据集中图片大小为28*28. 按照行进行展开成28维的特征向量. 考虑到这28个的向量之间存在着顺序依赖关系,我们可以将他们看成是一个长为28的输入序 ...

  6. Keras入门实战(1):MNIST手写数字分类

    目录 1)首先我们加载Keras中的数据集 2)网络架构 3)选择编译(compile参数) 4)准备图像数据 5) 训练模型 6)测试数据 前面的博客中已经介绍了如何在Ubuntu下安装Keras深 ...

  7. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  8. MNIST数据集手写数字分类

    参考   MNIST数据集手写数字分类 - 云+社区 - 腾讯云 目录 0.编程环境 1.下载并解压数据集 2.完整代码 3.数据准备 4.数据观察 4.1 查看变量mnist的方法和属性 4.2 对 ...

  9. 机器学习算法(九): 基于线性判别LDA模型的分类(基于LDA手写数字分类实践)

    机器学习算法(九): 基于线性判别模型的分类 1.前言:LDA算法简介和应用 1.1.算法简介 线性判别模型(LDA)在模式识别领域(比如人脸识别等图形图像识别领域)中有非常广泛的应用.LDA是一种监 ...

  10. 神经网络和深度学习(二)——一个简单的手写数字分类网络

    本文转自:https://blog.csdn.net/qq_31192383/article/details/77198870 一个简单的手写数字分类网络 接上一篇文章,我们定义了神经网络,现在我们开 ...

最新文章

  1. Markdown 如何编写表格(格式)?
  2. USTC English Club Note20171020(3)
  3. YUM更换源--yum找不到安装包(转)
  4. linux加密框架 crypto 算法crypto_register_alg的注册流程
  5. mybatis对java自定义注解的使用——入门篇
  6. Python利用结巴模块统计《水浒传》词频
  7. Python:Flask简单实现统计网站访问量
  8. Google Chrome谷歌浏览器去掉右上角更新提示图标
  9. 单纯技术背景已不吃香 MBA管理能力更被招聘者看好
  10. 响铃:360浏览器首创自有根证书,不赚钱为哪般?
  11. 2. 硬件基础知识学习
  12. Fabric 超级账本学习【1】Fabcar网络调用Fabric-Java-SDK进行简单开发 FabCar
  13. pythonsuper继承规则,Python用super继承
  14. 浮动的简介——CSS
  15. MySQL概述以及下载安装
  16. 忆恩师刘自朗,我的高中物理老师
  17. 【转载】三国演义里将领的身高
  18. java---约数个数(每日一道算法2022.9.10)
  19. JqueryEasyUI教程
  20. 用php做一个网站,如何用PHP开发一个完整的网站

热门文章

  1. SpringMVC框架学习上篇
  2. jmeter如何看tps_jmeter性能测试疑难杂症解决思路
  3. IDEA安装mysql程序包,程序包的下载!以及程序包配置到项目详解(更适合英语小白)
  4. python 动态加载代码_python 动态网页爬取?(不是加载更多页的动态网页哟)?...
  5. html javascript 表格id,javascript 获取表格中元素id的实现代码
  6. 未解bug001:SSM整合的过程中单元测试用Junit5复合注解整合失败
  7. position:relative/absolute无法冲破的等级
  8. Windows 2016 减肥
  9. 实验:DHCP中继代理
  10. android ViewFlipper的使用