从零开始的神经网络构建历程(二,用全连接前馈神经网络识别手写数字mnist)
本系列的上一篇博文最后提出了一个问题,是有关如何通过torch来实现给定的神经网络的,这里公布一下我自己的回答:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer12 = nn.Linear(784, 200)self.layer23 = nn.Linear(200, 100)self.layer34 = nn.Linear(100, 10)def forward(self, input):output = self.layer12(input)output = torch.relu(output)output = self.layer23(output)output = torch.relu(output)output = self.layer34(output)output = nn.functional.softmax(output)return output
是不是感觉很简单呢,先别着急,构建一个神经网络容易,但是要训练其中的参数是很麻烦的。
此篇博文将以sklearn中的手写数字图像作为数据集,来讲述一个神经网络是如何训练和测试的。
数据集的引入
Python scikit-learn库中有一个datasets模块,此模块中收录了很多经典的训练模型用到的数据集,mnist(手写数字图像,分辨率8*8)便是其中之一。本篇博文便是基于此数据集对图片中的手写数字进行识别和分类,此数据集是神经网络初学者的必经之路。
载入这些数字图像的方法很简单,详见以下代码:
def preprocess_digit():data = load_digits()x, y = data.data, data.target
## Input Normalizationx = MinMaxScaler().fit_transform(x)return train_test_split(x, y, test_size=0.1)
一般来说,在训练分类型神经网络模型时,输入的数据是要进行标准化的,这在以上代码中的“x = MinMaxScaler().fit_transform(x)”有所体现。本人在此之前未对数据进行标准化的预处理,造成了训练时参数的爆炸…(沉痛的教训,宝贵的经验)
把上面的函数进行调用,执行下面代码,可以得到相关的数据集的信息:
xtrain, xtest, ytrain, ytest = preprocess_digit()
print(xtrain.shape)
print(xtest.shape)
print(ytrain.shape)
print(ytest.shape)
此代码可以查看数据量的大小:
说明训练集数据有1617张图像,测试集有180张图像,所有图像的大小为88,但是载入的数据是164的行向量。
神经网络的搭建
之前讲过这种代码该如何写,不多说,直接上代码:
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer12 = nn.Linear(64, 50)self.layer23 = nn.Linear(50, 25)self.layer34 = nn.Linear(25, 10)def forward(self, input):output = self.layer12(input)output = torch.relu(output)output = self.layer23(output)output = torch.relu(output)output = self.layer34(output)output = nn.functional.softmax(output)return output
训练流程
这里通过实例代码来阐述torch是如何进行网络的训练的,首先先贴出我训练成功时所用的代码:
def train():network = Net()optimizer, loss_func = optim.SGD(network.parameters(), lr=0.5), nn.MSELoss()for epoch in range(len(y_train)):x = Variable(torch.tensor(x_train[epoch], dtype=torch.float32)).reshape([1, 64])y = torch.zeros([1, 10])y[0, y_train[epoch]] = 1.0for i in range(10):prediction = network(x)loss = loss_func(prediction, y)optimizer.zero_grad() ## Clear to zero for the devitation of loss(loss-weight)loss.backward()optimizer.step()
一行一行来看:
- 此Python函数中第一行的“network = Net()”表示建立一个Net类的对象,Net就是之前定义的此神经网络的类,继承于父类nn.Module;
- 语句optim.SGD(network.parameters(), lr=0.5)表示我们定义了一个优化器optimizer,并采用随机梯度下降法优化模型参数,设置每一层网络参数优化时的学习率为0.5。optim是torch中的模块,专门用于优化学习模型中的参数的,SGD是“随机梯度下降”的英文简写。此时我们设置SGD方法中的两个参数:(1)待优化的对象(network.parameters(),.parameters()表示取其网络参数);(2)学习率learn_rate(简写为lr)。
- 语句nn.MSELoss()意思是指定网络的损失函数loss_func为均方差(MSE)损失函数;
- 接下来便是两层循环,本人在此训练的逻辑是:对于每一个训练数据集中的样本,都送入网络中进行随机梯度下降法的训练,每个样本反复进行10次BP算法;
- 外层循环中有三句话,这三句话的含义分别为:(1)将每一个到来的样本输入转化为torch型张量(torch.Tensor),并将其维度调整为一个长度为64的行向量(此处可依据实际情况,判断是否需要维度调整);(2)后两句话含义是,针对此数据的标签Label,制作一个标准输出向量y,此向量将在内层循环中参与计算生成损失函数loss。
- 语句prediction = network(x),表示把x作为变量输入网络Net对象network中,最终得到的输出向量作为预测变量prediction;
- 语句loss = loss_func(prediction, y)声明loss函数与prediction和y有关,其作用和Loss的定义式
相当,等效于在此定义了Loss函数的形式。(yhat意思是y的估计,Loss函数是y估计与y差值的二范数平方)
- 语句optimizer.zero_grad()意思是将优化器中的梯度清零,因为每开始此BP算法,网络的初始梯度都是0;
- 语句loss.backward()意思是执行反向传播过程,在之前定义的model类里,我们只定义了前向传播过程forward,此时反向传播过程backward就已经确定了,所以无需自己再去定义(除非网络有特殊需要);
- 语句***optimizer.step()意思是开始一次迭代优化的步骤***。之前的过程都只是自定义训练的步骤,而非真正让机器执行,只有当调用optimizer.step()之后,程序才会开始优化,否则等于没干。
以上便是利用torch训练一个神经网络的相关步骤,在数据集很大的情况下,训练的步骤耗时很大,建议程序运行时做点别的事~
保存网络参数
保存网络参数很简单,直接上代码
if not os.path.isdir('model'):os.mkdir('model')torch.save(network, './model/full_connected.cpkt')
代码的逻辑很容易看懂,当model文件夹不存在时,创建model文件夹,然后在这个文件夹里面保存网络参数。
torch.save()可以保存的参数种类有很多,除了上面代码里面的类类型network以外,还有torch.Tensor、字典、参数列表network.parameters()等。
网络参数保存为文件之后,再次使用时就需要下面代码读取模型参数:
network = torch.load('./model/full_connected.cpkt')
network.eval()
也很容易理解,在torch.load中添加文件的路径即可,但是你事先要知道这个文件中存储的参数类型,并且要用正确的变量类型来接收参数读取的结果。
模型测试
直接给出代码:
for epoch in range(len(y_test)):x = Variable(torch.tensor(x_test[epoch], dtype=torch.float32)).reshape([1, L])prediction = network(x)max = torch.max(prediction).detach().numpy()if max == prediction[0, y_test[epoch]].detach().numpy():count += 1print("The accuracy of the testSet(%s) : %2.2f\n" % (set, (count/len(y_test))))
和训练时的做法类似,只是此时没有了反向传播步骤,当我们得到最终的预测向量prediction作为输出时,prediction中最大元素所在的下标便是预测的结果,例如,如果最大元素所在下标为8,则程序识别出来的数字是8。prediction是一个长度为10的向量,原因就是0~9包含了10个数字,prediction本质上存储的是每个数字被识别的概率,其中概率最大的就是机器识别的结果,这是神经网络用于分类问题时普遍采用的策略。
效果展示
以上的图片给出了识别正确的数字的例子。
这里打印的是模型识别数字的准确率,可以发现,识别的准确率还算可以,由于本人在这里采用的是随机梯度下降法进行参数更新,所以网络有点过拟合。一般的情况下,在训练网络时都是优先采用批量梯度下降(BGD)法进行网络训练,这里只是做训练的简单演示,所以未考虑此问题。
代码整合
所有与全连接前馈神经网络有关的代码都在下列的Gitee链接中,本人在今后还会向此代码库增加与神经网络相关的内容。
神经网络代码Gitee链接
问题
如果要用神经网络进行非线性回归,代码应该怎么写,你有大致的思路吗?
从零开始的神经网络构建历程(二,用全连接前馈神经网络识别手写数字mnist)相关推荐
- Python神经网络识别手写数字-MNIST数据集
Python神经网络识别手写数字-MNIST数据集 一.手写数字集-MNIST 二.数据预处理 输入数据处理 输出数据处理 三.神经网络的结构选择 四.训练网络 测试网络 测试正确率的函数 五.完整的 ...
- BP神经网络识别手写数字项目解析及代码
这两天在学习人工神经网络,用传统神经网络结构做了一个识别手写数字的小项目作为练手.点滴收获与思考,想跟大家分享一下,欢迎指教,共同进步. 平常说的BP神经网络指传统的人工神经网络,相比于卷积神经网络( ...
- 基于Numpy构建全连接前馈神经网络进行手写数字识别
文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (一) 问题描述 不使用任何机器学习框架,仅仅通过Numpy库构建一个最简单的 ...
- python手机代码识别数字_利用python构建神经网络识别手写数字(附源代码)
一.运行环境配置 本次实验的运行环境win10(bit64),采用python环境为3.7.6,安装Python环境推荐使用Anaconda.Anaconda是一个免费开源的Python和R语言的发行 ...
- 详细版【全连接前馈神经网络】(邱锡鹏)
全连接前馈神经网络 前馈神经网络又叫多层感知器,但其实是由多层 的 Logistic 回归 具体流程 令
- 全连接前馈神经网络DNN
全连接前馈神经网络DNN 1.DNN概述 前馈神经网络中,各神经元属于不同层,信号从输入层向输出层单向传播(有向无环图) 人工神经元模型:输入:x1,x2 权重:w1,w2 偏置:b ...
- 利用python实现简单的人工神经网络识别手写数字
利用 Python 搭建起了一个简单的神经网络模型,并完成识别手写数字. 1.前置工作 1.1 环境配置 这里使用scikit-learn库内建的手写数字字符集作为本文的数据集.scikit-lear ...
- BP神经网络理解原理——用Python编程实现识别手写数字(翻译英文文献)
BP神经网络理解原理--用Python编程实现识别手写数字 备注,这里可以用这个方法在csdn中编辑公式: https://www.zybuluo.com/codeep/note/163962 一 ...
- 华裔女性钱璐璐:用 DNA 开发人工智能神经网络,识别手写数字!
"既然要学人脑的思维方式,为什么不去研究人脑?"霍金斯在<论智能>中说道. 如今,不少生物学研究者正朝着这个方向努力. 不过,请注意:这不是一次传统意义上的生物实验. ...
最新文章
- np.eye()的函数能将一个label数组,大小为(1,m)或者(m,1)的数组,转化成one-hot数组
- 小程序循环不同的组建_小程序之八,对象数组、循环及条件渲染
- 模板三连击:树状数组+线段树+主席树
- 我用Python爬取1000封情书助力室友表白班花,却反转再反转...原来这就是班花的终极秘密!
- 将一个文本文件的内容按行读出,每读出一行就顺序加上行号,并写入到另一个文件中。...
- 用Java实现断点续传(HTTP)
- Games 图形学 L2线性代数
- mysql如何导出数据脚本_MySQL 导出数据
- 【案例20】NC系统was部署后无法登录
- 华为手机居然还能这样提高续航?简单设置一下,一天一充很轻松
- 从原理层面掌握@RequestAttribute、@SessionAttribute的使用【享学Spring MVC】
- 你可以成为测试界的李子柒
- JMM 8 大原子操作
- codeforces 855-B. Marvolo Gaunt's Ring(背包问题)
- 2022年8月29日 勒索病毒大爆发
- SiteRAS一款外贸网站SEO分析工具,给您的网站做个深度体检
- 腾讯、阿里、字节跳动三家公司有何区别!?
- 解决api-ms-win-core-processthreads-l1-1-1.dll文件丢失
- 11-课程详情页面静态化-课程信息模板设计
- npp夜光数据介绍 viirs_最新 夜光遥感影像VIIRSDMSP下载总结