目录

1. 加载数据:MNIST的train数据和test数据

2. 定义网络:一个简单的线性神经网络

3. 训练:使用定义的网络进行训练

拓展:使用visdom可视化_动态的绘制loss曲线

4. 测试:使用测试集,测试预测精度

5. 辅助工具函数

6. 拓展


1. 加载数据:MNIST的train数据和test数据

import torch
import torchvision  # 处理图像视频, 包含一些常用的数据集、模型、转换函数等等
from torch import nn, optim
from torch.nn import functional as Ffrom matplotlib import pyplot as plt
from utils import plot_curve, plot_image, one_hotbatch_size = 512
# step1. 加载数据集
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data_john', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data_john/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)# # test: 显示下训练数据集的前6张图像及对应的标签
# x,y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x,y,"image_gt")

显示下训练数据集的前6张图像及对应的标签:

2. 定义网络:一个简单的线性神经网络

这里采用 3个线性层,做简单示范。

# step2. 定义神经网络
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()# 定义三个线性层 y = wx+b# 输入X的size是:[batch_size, 28*28=784]# y = w1 * x + b1, 例如:参数数量:w1.size是[256,784] (一张图像x的size是[784,1]), b.size是[256]self.fc1 = nn.Linear(28 * 28, 256)  # 28*28是输入图像的大小,256是自定义的中间层大小# y = w2 * x + b2self.fc2 = nn.Linear(256, 64)       # 中间层数的结果,本层输入层数取决于上一层的输出层数,本层输出决定了下一层的输入层数。# y = w3 * x + b3self.fc3 = nn.Linear(64, 10)        # 10是要求的输出分类层数def forward(self, x):x = F.relu(self.fc1(x))             # 使用激活函数,增加非线性x = F.relu(self.fc2(x))x = self.fc3(x)                     # 最后一层根据网络结果输出return x

拓展:网络的另一种定义方式如下。

# step2'. 定义神经网络
class MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.model =  nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(inplace=True), # 使用inplace,省去反复申请和释放内存的时间。会对原变量覆盖。nn.Linear(256, 64),nn.LeakyReLU(inplace=True),nn.Linear(64, 10),nn.LeakyReLU(inplace=True))def forward(self, x):x = self.model(x)return x

3. 训练:使用定义的网络进行训练

# step3. 开始训练
net = MyNet()
# 定义梯度下降方式
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader): # train_loader中有 n 个 (x,y), 每一个 x和 y包含 batch_size张图像,所以总的图像数量是= batch_idx * batch_sizex = x.view(x.size(0), 28 * 28)  # view 等价于reshape, [512,1,28*28] => [512,28*28]out = net(x)                    # [batch_size,10]y_onehot = one_hot(y)           # 如将 [512,] 变成 [512,10], 原来一维数组中对应m行的值n,对应新的二维数组m行的第n列设置为1loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()optimizer.step()train_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)print (net.parameters())

loss的下降结果:

拓展:使用visdom可视化_动态的绘制loss曲线

1. 准备工作:a. 安装visdom: (pip install visdom) b.启动服务(python -m visdom.server) c.浏览器打开localhost (比如 http://localhost:8097)

界面如下:

2. 在代码中添加要监视的变量,如下viz相关的代码

from visdom import Visdomviz = Visdom()
index = 0
viz.line([0.],[0.],win='train_loss', opts=dict(title='train loss'))for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader): # train_loader中有 n 个 (x,y), 每一个 x和 y包含 batch_size张图像,所以总的图像数量是= batch_idx * batch_sizex = x.view(x.size(0), 28 * 28)  # view 等价于reshape, [512,1,28*28] => [512,28*28]out = net(x)                    # [batch_size,10]y_onehot = one_hot(y)           # 如将 [512,] 变成 [512,10], 原来一维数组中对应m行的值n,对应新的二维数组m行的第n列设置为1loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()optimizer.step()train_loss.append(loss.item())#可视化index += 1# viz.line([0.],[0.],win='train_loss', opts=dict(title='train loss'))viz.line([loss.item()],[index], win='train_loss', update='append')if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())

可视化效果(动态的):

4. 测试:使用测试集,测试预测精度

# step4. 进行测试,计算预测精度
total_correct = 0
for x, y in test_loader:x = x.view(x.size(0), 28 * 28)out = net(x)pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("acc: ", acc)# 可视化部分预测结果
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28)) # 二维 [batch_size, 10]
pred = out.argmax(dim=1)              # 一维 [batch_size,]
plot_image(x, pred, 'image_predict')

预测精度:

可视化部分预测结果:

5. 辅助工具函数

定义到uitls.py文件中:

用于显示图像,打印一维数组,one hot操作

import torch
from matplotlib import pyplot as plt# 绘制一维数据图
def plot_curve(data):fig = plt.figure()  # 定义一张图纸plt.plot(range(len(data)), data, color="blue")  # 绘制一维数组plt.legend(["value"], loc="upper right")        # 添加图例,即数据说明标签plt.xlabel("step")plt.ylabel("value")plt.show()def plot_image(img, label, name):''':param img:     比如:torch.Size([batch_size=512, 1, 28, 28]):param label:   比如:torch.Size([512]):param name:    string'''fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1) # 2*3个小图像plt.tight_layout()plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap="gray", interpolation="none")  # 图像进行正则化,然后显示出来plt.title("{}: {}".format(name, label[i].item()))  # 显示每一个plot的标题plt.xticks([])  # 设置x轴的刻度标签为空,即不显示刻度plt.yticks([])plt.show()def one_hot(label, depth=10):out = torch.zeros(label.size(0), depth) # 定义一个 [batch_size, 10]大小的矩阵idx = torch.LongTensor(label).view(-1, 1)  # 把 label reshape成 [batch_size, 1]尺寸的2维tensorout.scatter_(dim=1, index=idx, value=1)  # 改变 out的第dim=1维度的数据,out中值被改变值的索引,是index中对应的值, 填充的值是 1,# 如第2个样本是“6”,则第1行第5列(矩阵索引从0开始)填充为1.return out# if __name__ == '__main__':
#     data = [1,2,3,4,5,4,3,2,5,6,8]
#     plot_curve(data)

6. 拓展

使用自定义的线性回归模型,使用交叉熵,进行手写数字的识别。注意对比(注释掉35-37行 pk 启用35-37行)下面的参数w1 w2 w3的初始化方法(使用randn初始化时,loss经过几个epoch后会一直居高不下,使用kaiming的初始化方法后,loss会将下来很多。说明初始化对于网络的训练结果来说至关重要!!!)

注释掉35-37行:(randn初始化(高斯分布))

启用35-37行:(kaiming大佬的初始化)

参考:

深度学习入门_哔哩哔哩_bilibili

深度学习笔记_搭建一个简单网络(完整版)_手写数字识别MNIST相关推荐

  1. 读书笔记-深度学习入门之pytorch-第五章(含循环实现手写数字识别)(LSTM、GRU代码详解)

    目录 1.RNN优点:(记忆性) 2.循环神经网络结构与原理 3.LSTM(长短时记忆网络) 4.GRU 5.LSTM.RNN.GRU区别 6.收敛性问题 7.循环神经网络Pytorch实现 (1)R ...

  2. 深度学习入门(斋藤康毅)3.6手写数字识别_No module named ‘dataset‘ 问题解决

    深度学习入门在看这本书,到这里的时候,运行的时候遇到些问题(我是真的很菜- -) 1. No module named 'dataset'  2. SyntaxError: (unicode erro ...

  3. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  4. OpenCV4学习笔记(55)——基于KNN最近邻算法实现鼠标手写数字识别

    在上一篇博客<OpenCV4学习笔记(54)>中,整理了关于KNN最近邻算法的一些相关内容和一个手写体数字识别的例子.但是上次所实现的手写体数字识别,每次只能固定地输入测试图像进行预测,而 ...

  5. 机器学习笔记-神经网络的原理、数学、代码与手写数字识别

    机器学习笔记-神经网络 作者:星河滚烫兮 文章目录 前言 一.神经网络的灵感 二.基本原理 1.神经网络最小单元--神经元 2.神经网络层结构 3.正向传播 4.反向传播 5.梯度下降 三.数学理论推 ...

  6. matlab 对mnist手写数字数据集进行判决分析_人工智能TensorFlow(十四)MINIST手写数字识别...

    MNIST是一个简单的视觉计算数据集,它是像下面这样手写的数字图片: MNIST 每张图片还额外有一个标签记录了图片上数字是几,例如上面几张图的标签就是:5.0.4.1. MINIST数据 MINIS ...

  7. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  8. 基于人工智能方法的手写数字图像识别_【工程分析】基于ResNet的手写数字识别...

    ねぇ 呐 私に気付いてよ 快点注意到我吧 もう そんな事 那种事 一定 望んでも 再去奢求 しょうがないだろ 也无可奈何吧 --真野あゆみ<Bipolar emotion>(作詞:Mits ...

  9. 莫烦keras学习代码二(手写数字识别MNIST classifier CNN版)

    知道了CNN的原理,同样是只要将之前用tensorflow写的几个建立网络的函数用keras的更简单的方法替换就行. 训练结果: 用Sequential().add()添加想要的层,添加卷积层就用Co ...

  10. 菜菜学paddle第一篇:单层网络构建手写数字识别

    前言: 1.数字识别是计算机从纸质文档.照片或其他来源接收.理解并识别可读的数字的能力,目前比较受关注的是手写数字识别.手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别.手写邮政编码 ...

最新文章

  1. 中文版GPT-3来了?智源研究院发布清源 CPM —— 以中文为核心的大规模预训练模型...
  2. ios 小数保留位数
  3. 海康sdk返回yv12
  4. overflow encountered in exp
  5. 【杂谈】为了让大家学好深度学习模型设计和优化,有三AI都做了什么
  6. Qt加载本地字体 .ttc或.ttf
  7. MS UC 2013-0-虚拟机-标准化-部署-1-虚拟化-部署
  8. linuxoracle静默安装应答文件修改_Oracle 19c的examples静默安装
  9. 【BZOJ3451】Normal【期望线性性】【点分治】【NTT卷积】
  10. c# 无损高质量压缩图片代码
  11. html表单代码例子_关于React的这些细节,你知道吗?-表单
  12. mnist torch加载fashion_Pytorch加载并可视化FashionMNIST指定层(Udacity)
  13. nginx的目录结构和配置文件
  14. 【渗透测试实战】PHP语言有哪些后门?以及利用方法
  15. [转]gcc下程序调用静态库编译命令:主文件必须在静态库前面!
  16. 捕获事件要比冒泡事件先触发
  17. 洛谷P2770 航空路线问题(费用流)
  18. 利用arcgis基本比例尺标准分幅编号流程
  19. 【CodeForces】gym-101205B Curvy Little Bottles (2012 ACM-ICPC World Finals B)
  20. Windows下winrm的网络类型公用的问题解决

热门文章

  1. android自定义application,Android Test Orchestrator和自定义Application类
  2. Hbase之过滤器的使用
  3. LVS虚拟服务器的实现方式
  4. ReentrantLock中的公平锁与非公平锁
  5. Math,Number
  6. Exchange管理界面
  7. POJ 1088解题报告
  8. GridBagLayout用法
  9. Java Thread.yield详解
  10. Springboot整合JasperReport报表以及报表打印功能