本系列的上一篇博文最后提出了一个问题,是有关如何通过torch来实现给定的神经网络的,这里公布一下我自己的回答:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer12 = nn.Linear(784, 200)self.layer23 = nn.Linear(200, 100)self.layer34 = nn.Linear(100, 10)def forward(self, input):output = self.layer12(input)output = torch.relu(output)output = self.layer23(output)output = torch.relu(output)output = self.layer34(output)output = nn.functional.softmax(output)return output

是不是感觉很简单呢,先别着急,构建一个神经网络容易,但是要训练其中的参数是很麻烦的。

此篇博文将以sklearn中的手写数字图像作为数据集,来讲述一个神经网络是如何训练和测试的。

数据集的引入

Python scikit-learn库中有一个datasets模块,此模块中收录了很多经典的训练模型用到的数据集,mnist(手写数字图像,分辨率8*8)便是其中之一。本篇博文便是基于此数据集对图片中的手写数字进行识别和分类,此数据集是神经网络初学者的必经之路。

载入这些数字图像的方法很简单,详见以下代码:

def preprocess_digit():data = load_digits()x, y = data.data, data.target
## Input Normalizationx = MinMaxScaler().fit_transform(x)return train_test_split(x, y, test_size=0.1)

一般来说,在训练分类型神经网络模型时,输入的数据是要进行标准化的,这在以上代码中的“x = MinMaxScaler().fit_transform(x)”有所体现。本人在此之前未对数据进行标准化的预处理,造成了训练时参数的爆炸…(沉痛的教训,宝贵的经验)

把上面的函数进行调用,执行下面代码,可以得到相关的数据集的信息:

xtrain, xtest, ytrain, ytest = preprocess_digit()
print(xtrain.shape)
print(xtest.shape)
print(ytrain.shape)
print(ytest.shape)

此代码可以查看数据量的大小:

说明训练集数据有1617张图像,测试集有180张图像,所有图像的大小为88,但是载入的数据是164的行向量。

神经网络的搭建

之前讲过这种代码该如何写,不多说,直接上代码:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layer12 = nn.Linear(64, 50)self.layer23 = nn.Linear(50, 25)self.layer34 = nn.Linear(25, 10)def forward(self, input):output = self.layer12(input)output = torch.relu(output)output = self.layer23(output)output = torch.relu(output)output = self.layer34(output)output = nn.functional.softmax(output)return output

训练流程

这里通过实例代码来阐述torch是如何进行网络的训练的,首先先贴出我训练成功时所用的代码:

def train():network = Net()optimizer, loss_func = optim.SGD(network.parameters(), lr=0.5), nn.MSELoss()for epoch in range(len(y_train)):x = Variable(torch.tensor(x_train[epoch], dtype=torch.float32)).reshape([1, 64])y = torch.zeros([1, 10])y[0, y_train[epoch]] = 1.0for i in range(10):prediction = network(x)loss = loss_func(prediction, y)optimizer.zero_grad() ## Clear to zero for the devitation of loss(loss-weight)loss.backward()optimizer.step()

一行一行来看:

  1. 此Python函数中第一行的“network = Net()”表示建立一个Net类的对象,Net就是之前定义的此神经网络的类,继承于父类nn.Module;
  2. 语句optim.SGD(network.parameters(), lr=0.5)表示我们定义了一个优化器optimizer,并采用随机梯度下降法优化模型参数,设置每一层网络参数优化时的学习率为0.5。optim是torch中的模块,专门用于优化学习模型中的参数的,SGD是“随机梯度下降”的英文简写。此时我们设置SGD方法中的两个参数:(1)待优化的对象(network.parameters(),.parameters()表示取其网络参数);(2)学习率learn_rate(简写为lr)。
  3. 语句nn.MSELoss()意思是指定网络的损失函数loss_func为均方差(MSE)损失函数;
  4. 接下来便是两层循环,本人在此训练的逻辑是:对于每一个训练数据集中的样本,都送入网络中进行随机梯度下降法的训练,每个样本反复进行10次BP算法;
  5. 外层循环中有三句话,这三句话的含义分别为:(1)将每一个到来的样本输入转化为torch型张量(torch.Tensor),并将其维度调整为一个长度为64的行向量(此处可依据实际情况,判断是否需要维度调整);(2)后两句话含义是,针对此数据的标签Label,制作一个标准输出向量y,此向量将在内层循环中参与计算生成损失函数loss。
  6. 语句prediction = network(x),表示把x作为变量输入网络Net对象network中,最终得到的输出向量作为预测变量prediction;
  7. 语句loss = loss_func(prediction, y)声明loss函数与prediction和y有关,其作用和Loss的定义式

相当,等效于在此定义了Loss函数的形式。(yhat意思是y的估计,Loss函数是y估计与y差值的二范数平方)

  1. 语句optimizer.zero_grad()意思是将优化器中的梯度清零,因为每开始此BP算法,网络的初始梯度都是0;
  2. 语句loss.backward()意思是执行反向传播过程,在之前定义的model类里,我们只定义了前向传播过程forward,此时反向传播过程backward就已经确定了,所以无需自己再去定义(除非网络有特殊需要);
  3. 语句***optimizer.step()意思是开始一次迭代优化的步骤***。之前的过程都只是自定义训练的步骤,而非真正让机器执行,只有当调用optimizer.step()之后,程序才会开始优化,否则等于没干

以上便是利用torch训练一个神经网络的相关步骤,在数据集很大的情况下,训练的步骤耗时很大,建议程序运行时做点别的事~

保存网络参数

保存网络参数很简单,直接上代码

    if not os.path.isdir('model'):os.mkdir('model')torch.save(network, './model/full_connected.cpkt')

代码的逻辑很容易看懂,当model文件夹不存在时,创建model文件夹,然后在这个文件夹里面保存网络参数。

torch.save()可以保存的参数种类有很多,除了上面代码里面的类类型network以外,还有torch.Tensor、字典、参数列表network.parameters()等。

网络参数保存为文件之后,再次使用时就需要下面代码读取模型参数:

network = torch.load('./model/full_connected.cpkt')
network.eval()

也很容易理解,在torch.load中添加文件的路径即可,但是你事先要知道这个文件中存储的参数类型,并且要用正确的变量类型来接收参数读取的结果。

模型测试

直接给出代码:

    for epoch in range(len(y_test)):x = Variable(torch.tensor(x_test[epoch], dtype=torch.float32)).reshape([1, L])prediction = network(x)max = torch.max(prediction).detach().numpy()if max == prediction[0, y_test[epoch]].detach().numpy():count += 1print("The accuracy of the testSet(%s) : %2.2f\n" % (set, (count/len(y_test))))

和训练时的做法类似,只是此时没有了反向传播步骤,当我们得到最终的预测向量prediction作为输出时,prediction中最大元素所在的下标便是预测的结果,例如,如果最大元素所在下标为8,则程序识别出来的数字是8。prediction是一个长度为10的向量,原因就是0~9包含了10个数字,prediction本质上存储的是每个数字被识别的概率,其中概率最大的就是机器识别的结果,这是神经网络用于分类问题时普遍采用的策略。

效果展示




以上的图片给出了识别正确的数字的例子。

这里打印的是模型识别数字的准确率,可以发现,识别的准确率还算可以,由于本人在这里采用的是随机梯度下降法进行参数更新,所以网络有点过拟合。一般的情况下,在训练网络时都是优先采用批量梯度下降(BGD)法进行网络训练,这里只是做训练的简单演示,所以未考虑此问题。

代码整合

所有与全连接前馈神经网络有关的代码都在下列的Gitee链接中,本人在今后还会向此代码库增加与神经网络相关的内容。
神经网络代码Gitee链接

问题

如果要用神经网络进行非线性回归,代码应该怎么写,你有大致的思路吗?

从零开始的神经网络构建历程(二,用全连接前馈神经网络识别手写数字mnist)相关推荐

  1. Python神经网络识别手写数字-MNIST数据集

    Python神经网络识别手写数字-MNIST数据集 一.手写数字集-MNIST 二.数据预处理 输入数据处理 输出数据处理 三.神经网络的结构选择 四.训练网络 测试网络 测试正确率的函数 五.完整的 ...

  2. BP神经网络识别手写数字项目解析及代码

    这两天在学习人工神经网络,用传统神经网络结构做了一个识别手写数字的小项目作为练手.点滴收获与思考,想跟大家分享一下,欢迎指教,共同进步. 平常说的BP神经网络指传统的人工神经网络,相比于卷积神经网络( ...

  3. 基于Numpy构建全连接前馈神经网络进行手写数字识别

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (一) 问题描述 不使用任何机器学习框架,仅仅通过Numpy库构建一个最简单的 ...

  4. python手机代码识别数字_利用python构建神经网络识别手写数字(附源代码)

    一.运行环境配置 本次实验的运行环境win10(bit64),采用python环境为3.7.6,安装Python环境推荐使用Anaconda.Anaconda是一个免费开源的Python和R语言的发行 ...

  5. 详细版【全连接前馈神经网络】(邱锡鹏)

    全连接前馈神经网络 前馈神经网络又叫多层感知器,但其实是由多层 的 Logistic 回归 具体流程 令

  6. 全连接前馈神经网络DNN

    全连接前馈神经网络DNN 1.DNN概述 前馈神经网络中,各神经元属于不同层,信号从输入层向输出层单向传播(有向无环图) 人工神经元模型:输入:x1,x2    权重:w1,w2      偏置:b  ...

  7. 利用python实现简单的人工神经网络识别手写数字

    利用 Python 搭建起了一个简单的神经网络模型,并完成识别手写数字. 1.前置工作 1.1 环境配置 这里使用scikit-learn库内建的手写数字字符集作为本文的数据集.scikit-lear ...

  8. BP神经网络理解原理——用Python编程实现识别手写数字(翻译英文文献)

    BP神经网络理解原理--用Python编程实现识别手写数字   备注,这里可以用这个方法在csdn中编辑公式: https://www.zybuluo.com/codeep/note/163962 一 ...

  9. 华裔女性钱璐璐:用 DNA 开发人工智能神经网络,识别手写数字!

    "既然要学人脑的思维方式,为什么不去研究人脑?"霍金斯在<论智能>中说道. 如今,不少生物学研究者正朝着这个方向努力. 不过,请注意:这不是一次传统意义上的生物实验. ...

最新文章

  1. np.eye()的函数能将一个label数组,大小为(1,m)或者(m,1)的数组,转化成one-hot数组
  2. 小程序循环不同的组建_小程序之八,对象数组、循环及条件渲染
  3. 模板三连击:树状数组+线段树+主席树
  4. 我用Python爬取1000封情书助力室友表白班花,却反转再反转...原来这就是班花的终极秘密!
  5. 将一个文本文件的内容按行读出,每读出一行就顺序加上行号,并写入到另一个文件中。...
  6. 用Java实现断点续传(HTTP)
  7. Games 图形学 L2线性代数
  8. mysql如何导出数据脚本_MySQL 导出数据
  9. 【案例20】NC系统was部署后无法登录
  10. 华为手机居然还能这样提高续航?简单设置一下,一天一充很轻松
  11. 从原理层面掌握@RequestAttribute、@SessionAttribute的使用【享学Spring MVC】
  12. 你可以成为测试界的李子柒
  13. JMM 8 大原子操作
  14. codeforces 855-B. Marvolo Gaunt's Ring(背包问题)
  15. 2022年8月29日 勒索病毒大爆发
  16. SiteRAS一款外贸网站SEO分析工具,给您的网站做个深度体检
  17. 腾讯、阿里、字节跳动三家公司有何区别!?
  18. 解决api-ms-win-core-processthreads-l1-1-1.dll文件丢失
  19. 11-课程详情页面静态化-课程信息模板设计
  20. npp夜光数据介绍 viirs_最新 夜光遥感影像VIIRSDMSP下载总结

热门文章

  1. 浅谈1394总线的那点事
  2. 真相只有一个——谁是凶手
  3. 深入探讨静态路由的next-hop选项(discard/receive/reject)
  4. GPS 校验和 代码_YBS-YFQ-100K真空压力校验台气压校验压力台表压力变送器
  5. 基于asp.net的幼儿园接送信息管理系统-计算机毕业设计
  6. 纷享销客CRM自定义函数:创建自定义对象数据
  7. 高数——导数的意义——学习笔记
  8. 简单Android app之 一键签到 开发日记
  9. 程序员的专属微信公众号编辑器:定制 Markdown 转 HTML
  10. 家教信息管理系统的设计与实现