文章目录

  • 全连接网络手写数字识别
    • 定义类:创建数据集
      • Dataset类
      • \_\_init\_\_函数
      • **\_\_getitem\_\_**与\_\_len\_\_
      • make_txt_file 函数
      • 测试
    • 定义类:全连接网络
      • super(HandWritingNumberRecognize_Network, self).\_\_init\_\_()
      • forward函数:
      • 网络结构
    • 训练函数
      • optimizer.zero_grad
      • 损失函数
    • 验证函数
      • torch.no_grad()
      • 计算正确率
    • 测试函数
    • main
      • optimizer =torch.optim.Adam
  • 调试过程
  • 引用与参考

全连接网络手写数字识别

pytorch

全连接网络

mnist数据集

有验证集

定义类:创建数据集

class HandWritingNumberRecognize_Dataset(Dataset):def __init__(self,datasetfold_path,dataset_type,transform=None,update_dataset=False):# 这里添加数据集的初始化内容#dataset_path = 'E:\桌面\dml\深度学习课程-实验1\深度学习课程-实验1\datas\dataset\\'dataset_path = datasetfold_pathif update_dataset:make_txt_file(dataset_path+dataset_type,dataset_type) # update datalist in dataset_path# make_txt_file函数:写一个文件,用来对应图片路径和标签self.transform = transforms.ToTensor()self.sample_list = list()self.dataset_type = dataset_typef = open(dataset_path + self.dataset_type + '/datalist.txt')lines = f.readlines()for line in lines:self.sample_list.append(line.strip())  # 这是重构之后的datalist文件。f.close()#passdef __getitem__(self, index):# 这里添加getitem函数的相关内容item = self.sample_list[index]img = Image.open(item.split(' _')[0])if self.transform is not None:img = self.transform(img)if self.dataset_type == 'test' :label = 0else:label = int(item.split(' _')[-1])return img, label#passdef __len__(self):# 这里添加len函数的相关内容return len(self.sample_list)#pass
Dataset类

这段代码可以看出,我们构建的class HandWritingNumberRecognize_Dataset,是继承了Dataset的子类。

Dataset是pytorch 给出的类,源代码如下:

class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

我们要通过dataset类完成“数据和标签”的对应,具体的构建代码,需要先观察自己的数据集格式,保证能读到文件,并且对应上。

在本实验中,包括一下几个内容:

  1. 打开并读取标签文件。本实验老师提供的数据集标签文件中,只有数字,没有任何其他内容,所以不需要切片。
  2. 用PIL.Iimage.open读取图片,图片路径可以用递增写
  3. 图片如果需要transform操作,操作一下
__init__函数
  • 构造函数,在调用类创建实例对象时候,会自动调用
  • 参数:第一个参数必须为self,然后需要传入数据集的path,数据集的类型,是否进行transform操作。
  • 这里我加了一个# make_txt_file # 函数,用来规范书写(x,y)对应关系,后面getitem就可以适应各种函数。只需要在make_txt_file里面观察具体实验使用的数据集。
  • 把重构之后的(x,y)信息读到self.sample_list里面
**__getitem__**与__len__

__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}

在这里,把图片由路径读出图片;把标签内容读到label里面

——这里需要讨论一下,对于本实验中的test数据集,不需要加载label,直接填充为0.

__len__是指数据集长度。

由源码可以看出,继承的时候一定要覆写这两个函数,不然会返回错误。

make_txt_file 函数
def make_txt_file(path,dataset_type):labellist=list()if dataset_type=='test':files = os.listdir(path+'\images')   # 读入文件夹num_png = len(files)  #统计test文件夹中图片的个数f = open(path+'/datalist.txt','w') #写对应文件,这里用'w'方法,覆盖原文件countimg=0while countimg<num_png:f.write(path+'\images\\'+dataset_type +'_'+str(countimg)+'.bmp\n')countimg+=1f.close()else:f = open(path + '\labels_'+dataset_type+'.txt')lines = f.readlines()for line in lines:labellist.append(line.strip())  #把读到的纯标签写入列表f.close()f = open(path+'/datalist.txt','w') #写对应文件,这里用'w'方法,覆盖原文件countimg=0for i in labellist:f.write(path+'\images\\'+dataset_type +'_'+str(countimg)+'.bmp'+' _'+i+'\n')#写完的效果是“图片路径 _标签”countimg+=1f.close()

这个函数要观察数据集,本实验提供的数据集中,图片为一个文件夹,中图片名数字递增;train和validation有label文件,里面每一行为一个标签,和文件夹里面图片顺次对应。

这里用**os.listdir(path+’\images’)**来读文件夹

测试

构造完之后,测试一下

刚开始测试时候,遇到了IndexError: list index out of range 错误,可能有两个原因:

  • 一个可能是下标超出范围,
  • 一个可能是list是空的,没有一个元素’

检查代码,发现是读datalist时候,里面只有数字,没有图片路径。于是重写了def make_txt_file(path,dataset_type):函数,解决了。

if __name__ == '__main__':ds = HandWritingNumberRecognize_Dataset('train',update_dataset=True)print(ds.__len__())img, gt = ds.__getitem__(100) # get the 34th sampleprint(type(img))print(gt)

再看一下自己写的makefile

在make_txt_file代码里面可以看到,我用的都是绝对路径

定义类:全连接网络

class HandWritingNumberRecognize_Network(torch.nn.Module):def __init__(self,in_dim,out):super(HandWritingNumberRecognize_Network, self).__init__()self.in_dim=in_dim# 此处添加网络的相关结构,下面的pass不必保留self.layer = torch.nn.Linear(in_dim,out)#self.layer2 = torch.nn.Linear(hidden,out)#passdef forward(self, input_data):# 此处添加模型前馈函数的内容,return函数需自行修改x = self.layer(input_data)#x = self.layer2(x)x = torch.nn.functional.softmax(x)# x = self.layer3(x)return x
super(HandWritingNumberRecognize_Network, self).__init__()

在定义类时可见,我们继承的是torch.nn.Module类

这个super函数的意思是, 继承了父类(nn.module)的初始化设置, 这一步一定要写

继承 nn.Module 的神经网络模块在实现自己的 init 函数时,一定要先调用 super().init()

这是nn.module 的 init:

self.training = True  # 控制 training/testing 状态self._parameters = OrderedDict()  # 在训练过程中会随着 BP 而更新的参数self._buffers = OrderedDict()  # 在训练过程中不会随着 BP 而更新的参数self._non_persistent_buffers_set = set()self._backward_hooks = OrderedDict()  # Backward 完成后会被调用的 hookself._forward_hooks = OrderedDict()  # Forward 完成后会被调用的 hookself._forward_pre_hooks = OrderedDict()  # Forward 前会被调用的 hookself._state_dict_hooks = OrderedDict()  # 得到 state_dict 以后会被调用的 hookself._load_state_dict_pre_hooks = OrderedDict()  # load state_dict 前会被调用的 hookself._modules = OrderedDict()  # 子神经网络模块
forward函数:

这是torch.nn.Module的一个方法,类似的还有

class Module(object):def __init__(self):def forward(self, *input):def add_module(self, name, module):def cuda(self, device=None):def cpu(self):def __call__(self, *input, **kwargs):def parameters(self, recurse=True):def named_parameters(self, prefix='', recurse=True):def children(self):def named_children(self):def modules(self):  def named_modules(self, memo=None, prefix=''):def train(self, mode=True):def eval(self):def zero_grad(self):def __repr__(self):def __dir__(self):
'''
有一部分没有完全列出来
'''

有一些注意技巧:

  1. 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然我也可以吧不具有参数的层也放在里面;
  2. 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
  3. forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心
网络结构

根据实验报告要求,只用了一层全连接,也没用relu激活,在最后用了一层softmax

训练函数

def train(epoch_num):# 循环外可以自行添加必要内容mode = Truemodel.train(mode=mode)#print(data_loader_train)for index,(images,true_labels) in enumerate(data_loader_train):#这里用index和 enumerate,为了统计训练到了第几个样本,中间可以返回过程值optimizer.zero_grad()output = model(images.reshape(-1,784))#这里不能用结果数字,需要用原来的概率和标签的独热编码比较。#保留非最大值的信息训练效果好one_hot = torch.zeros(1, 10).long()one_hot.scatter_(dim=1,index=true_labels.unsqueeze(dim=1),src=torch.ones(1, 10).long())loss = loss_function(output.float(), one_hot.float())#得到损失函数loss.backward() # 反向传播训练参数optimizer.step()# 必要的时候可以添加损失函数值的信息,即训练到现在的平均损失或最后一次的损失,下面两行不必保留if index % 3000 == 0:print(index, loss.item())  # 获取损失#print(epoch_num, images, true_labels)# pass
optimizer.zero_grad
损失函数

这里损失函数计算的时候,要用两个tensor计算,所以需要把label转为单热点的,和输出的概率向量计算。

同时,因为选用的loss时mse,需要转为float型计算

验证函数

def validation():# 验证函数,任务是在训练经过一定的轮数之后,对验证集中的数据进行预测并与真实结果进行比对,生成当前模型在验证集上的准确率correct = 0total = 0accuracy = 0with torch.no_grad():  # 该函数的意义需在实验报告中写明for data in data_loader_val:images, true_labels = data# 在这一部分撰写验证的内容,下面两行不必保留output = model(images.reshape(-1,784))pred = output.data.max(dim=-1)[-1]  # max(dim=-1):行角度找最大值,dim=0:列最大值,返回下标#在这里通过返回下标将概率转为数字结果,验证函数中用数字结果和标签对比,统计正确率if pred.equal(true_labels):correct+=1total+=1#print(images, true_labels)#passaccuracy=correct/totalprint("验证集数据总量:", total, "预测正确的数量:", correct)print("当前模型在验证集上的准确率为:", accuracy)
torch.no_grad()

torch.no_grad() 是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。

禁用梯度,仅在局部地区有效,将requires_grad改为false,用以节约计算资源。

optimizer.zero_grad()清除了优化器中所有 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OZXCrRQJ-1647959427155)(https://www.zhihu.com/equation?tex=x)] 的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xwQgzHLK-1647959427156)(https://www.zhihu.com/equation?tex=x.grad)] ,在每次loss.backward()之前,不要忘记使用,否则之前的梯度将会累积,这通常不是我们所期望的( 也不排除也有人需要利用这个功能)。

同样还可以用 torch.set_grad_enabled()来实现不计算梯度。
例如:

def eval():torch.set_grad_enabled(False)...  # your test codetorch.set_grad_enabled(True)

引用一条评论,体会一下这个函数作用:

在测试函数前加了装饰器,解决了cuda out of memory 感谢!!

计算正确率

这里计算正确率用的是对比结果数字。

所以用返回下标的方式,将概率向量中最大的下标返回成结果。

测试函数

def alltest(result_path):# 测试函数,需要完成的任务有:根据测试数据集中的数据,逐个对其进行预测,生成预测值。loss_list = []acc_list = []with torch.no_grad():#f = open('E:\桌面\dml\深度学习课程-实验1\\result.txt','w') #写对应文件,这里用'w'方法,覆盖原文件f = open(result_path,'w')for data in data_loader_test:images,label= dataoutput = model(images.reshape(-1,784)) pred = output.data.max(dim=-1)[-1]  # max(dim=-1):行角度找最大值,dim=0:列最大值,返回下标f.write(str(pred.item())+'\n')#将结果写入文档f.close()# 将结果按顺序写入txt文件中,下面一行不必保留pass

这里根据本次实验的要求,将test文件夹里面的图片通过网络之后的结果写入文件。

这里将路径写为参数,方便之后整理文件夹结构

main

if __name__ == "__main__":#先用绝对路径,调整文件夹结构之后,改为相对路径,方便迁移。datasetfold_path = 'E:\桌面\dml\深度学习课程-实验1\深度学习课程-实验1\datas\dataset\\'# 构建数据集,参数和值需自行查阅相关资料补充。dataset_train = HandWritingNumberRecognize_Dataset(datasetfold_path,'train',transform=True, update_dataset=True)dataset_val = HandWritingNumberRecognize_Dataset(datasetfold_path,'val',transform=True,update_dataset=True)dataset_test = HandWritingNumberRecognize_Dataset(datasetfold_path,'test',transform=True,update_dataset=True)# 构建数据加载器,参数和值需自行完善。data_loader_train = DataLoader(dataset=dataset_train)data_loader_val = DataLoader(dataset=dataset_val)data_loader_test = DataLoader(dataset=dataset_test)#print(data_loader_train)# 初始化模型对象,可以对其传入相关参数model = HandWritingNumberRecognize_Network(28*28*1,10)# 损失函数设置loss_function = torch.nn.MSELoss()  # torch.nn中的损失函数进行挑选,并进行参数设置# 优化器设置optimizer =torch.optim.Adam(model.parameters(), lr = 0.0001) # torch.optim中的优化器进行挑选,并进行参数设置max_epoch = 1  # 自行设置训练轮数num_val = 1  # 经过多少轮进行验证# 然后开始进行训练for epoch in range(max_epoch):train(epoch)# 在训练数轮之后开始进行验证评估if epoch % num_val == 0:validation()# 自行完善测试函数,并通过该函数生成测试结果result_path = 'E:\桌面\dml\深度学习课程-实验1\\result.txt'alltest(result_path)

这里的损失函数先选择的是MSEloss

优化器用的Adam

训练轮数和验证轮数都是先设定这样,自己电脑cpu跑一轮太慢了

optimizer =torch.optim.Adam

params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
lr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)

调试过程

本文的网络层数只有一层,不涉及到中间层层数和神经元个数

激活函数也定位softmax了

超参数有:优化器的选择(Adam的lr需要调)、损失函数的选择,训练的轮数

没啥经验,瞎调

引用与参考

  1. https://blog.csdn.net/leviopku/article/details/99958182
  2. https://blog.csdn.net/qq_27825451/article/details/90551513
  3. https://blog.csdn.net/qq_36108664/article/details/107205942
  4. Pytorch学习笔记07----nn.Module类与前向传播函数forward的理解 - 雨后观山色 - 博客园 (cnblogs.com)
  5. https://blog.csdn.net/qq_43082153/article/details/108579168
  6. https://blog.csdn.net/weixin_46559271/article/details/105658654
  7. 【PyTorch】实现手写数字识别_风口IT猪的成长录-CSDN博客
  8. (38条消息) TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists found_CV干饭王的博客-CSDN博客
  9. (38条消息) torch.optim优化算法理解之optim.Adam()_KGzhang的博客-CSDN博客_torch.optim.adam
  10. (38条消息) 完美分析 解决 python 中 not enough values to unpack (expected 2, got 1) 的 报错 。_风蓝天明的博客-CSDN博客

全连接网络手写数字识别(极详细,互助)相关推荐

  1. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  2. 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)

    主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...

  3. 【项目实践】:KNN实现手写数字识别(附Python详细代码及注释)

    ↑ 点击上方[计算机视觉联盟]关注我们 本节使用KNN算法实现手写数字识别.KNN算法基本原理前边文章已经详细叙述,盟友们可以参考哦! 数据集介绍 有两个文件: (1)trainingDigits文件 ...

  4. 深度学习(32)随机梯度下降十: 手写数字识别问题(层)

    深度学习(32)随机梯度下降十: 手写数字识别问题(层) 1. 数据集 2. 网络层 3. 网络模型 4. 网络训练 本节将利用前面介绍的多层全连接网络的梯度推导结果,直接利用Python循环计算每一 ...

  5. 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别

    "我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...

  6. pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...

  7. 简陋的CNN实现手写数字识别

    文章目录 前言 背景知识 Neural Network Backpropagation CNN pytorch 介绍 代码 CNN模型 训练&测试 前言 日常翘课,但是作业还是要写的. 数据集 ...

  8. 使用飞桨完成手写数字识别模型

    手写数字识别任务 数字识别是计算机从纸质文档.照片或其他来源接收.理解并识别可读的数字的能力,目前比较受关注的是手写数字识别.手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别.手写邮 ...

  9. 基于支持向量机的手写数字识别详解(MATLAB GUI代码,提供手写板)

    摘要:本文详细介绍如何利用MATLAB实现手写数字的识别,其中特征提取过程采用方向梯度直方图(HOG)特征,分类过程采用性能优异的支持向量机(SVM)算法,训练测试数据集为学术及工程上常用的MNIST ...

最新文章

  1. python中字符串前面加一个u或者r的区别
  2. 数据库技术mysql能干什么_MySQL外键有什么作用
  3. XHTML学习笔记 Part2:核心元素
  4. 18awg线材最大电流_小米生态链拉车线:2.4A大电流,苹果MFi认证,高速充电不断裂...
  5. mac_android_studio_环境搭建
  6. 梯度下降优化算法总结
  7. 【CCCC】L2-008 最长对称子串 (25分),直接枚举遍历
  8. 小D课堂 - 新版本微服务springcloud+Docker教程_2_04微服务下电商项目基础模块设计...
  9. Qt编写/注册/使用activex控件
  10. unity中Asset Store下载的资源保存位置
  11. 在手机相册(ios设备相册)中创建相册
  12. 网络安全笔记5——数字签名
  13. 谷歌Cartographer的论文研读(一)
  14. 如何设置计算机的休眠时间,电脑的睡眠时间如何设置?
  15. CA(证书颁发机构)
  16. 基于java+SpringBoot+HTML+Mysql旅游网站设计与实现
  17. 小娴的男友小旭不幸患了一种怪病,这种怪病吞噬了他的大部分记忆,同时让他突然间不会书写符合正确语序的英文。神奇的是,虽然他写出的句子看上去杂乱无章,不过经过仔细分析可以发现,如果把单词的顺序倒过来,语法
  18. NOIP2009年普及组初赛试题答案及解析
  19. 在JavaScript中组合字符串的4种方法
  20. ubuntu ibus安装极点五笔

热门文章

  1. VTK Examples中MIP和MPR的功能实现与源码分析
  2. 已知满二叉树先序序列存在于数组中,设计算法将其变成后序序列
  3. AcWing with LeetCode
  4. UML序列图 - 示例总汇
  5. arcgis英文版投影在哪_ArcGIS地图投影转换
  6. Shopify怎么添加发货方式
  7. 两款识图神器,你这个图我认识!
  8. Codeforces C. Numbers on Whiteboard (Round 96 Rated for Div.2) (思维 / 贪心)
  9. 麦块服务器正版登,我的世界麦块盒子正版
  10. 4. Bootstrap - intermediate