大家好,我是红色石头!

在上三篇文章:

这可能是神经网络 LeNet-5 最详细的解释了!

我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!

详细介绍了卷积神经网络 LeNet-5 的理论部分和使用 PyTorch 复现 LeNet-5 网络来解决 MNIST 数据集和 CIFAR10 数据集。然而大多数实际应用中,我们需要自己构建数据集,进行识别。因此,本文将讲解一下如何使用 LeNet-5 训练自己的数据。

正文开始!

三、用 LeNet-5 训练自己的数据

下面使用 LeNet-5 网络来训练本地的数据并进行测试。数据集是本地的 LED 数字 0-9,尺寸为 28x28 单通道,跟 MNIST 数据集类似。训练集 0-9 各 95 张,测试集 0~9 各 40 张。图片样例如图所示:

3.1 数据预处理

制作图片数据的索引

对于训练集和测试集,要分别制作对应的图片数据索引,即 train.txt 和 test.txt两个文件,每个 txt 中包含每个图片的目录和对应类别 class。示意图如下:

制作图片数据索引的 python 脚本程序如下:

import ostrain_txt_path = os.path.join("data", "LEDNUM", "train.txt")
train_dir = os.path.join("data", "LEDNUM", "train_data")
valid_txt_path = os.path.join("data", "LEDNUM", "test.txt")
valid_dir = os.path.join("data", "LEDNUM", "test_data")def gen_txt(txt_path, img_dir):f = open(txt_path, 'w')for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称for sub_dir in s_dirs:i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径for i in range(len(img_list)):if not img_list[i].endswith('jpg'):         # 若不是png文件,跳过continuelabel = img_list[i].split('_')[0]img_path = os.path.join(i_dir, img_list[i])line = img_path + ' ' + label + '\n'f.write(line)f.close()if __name__ == '__main__':gen_txt(train_txt_path, train_dir)gen_txt(valid_txt_path, valid_dir)

运行脚本之后就在 ./data/LEDNUM/ 目录下生成 train.txt 和 test.txt 两个索引文件。

构建Dataset子类

pytorch 加载自己的数据集,需要写一个继承自 torch.utils.data 中 Dataset 类,并修改其中的 __init__ 方法、__getitem__ 方法、__len__ 方法。默认加载的都是图片,__init__ 的目的是得到一个包含数据和标签的 list,每个元素能找到图片位置和其对应标签。然后用 __getitem__ 方法得到每个元素的图像像素矩阵和标签,返回 img 和 label。

from PIL import Image
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, txt_path, transform = None, target_transform = None):fh = open(txt_path, 'r')imgs = []for line in fh:line = line.rstrip()words = line.split()imgs.append((words[0], int(words[1])))self.imgs = imgs self.transform = transformself.target_transform = target_transformdef __getitem__(self, index):fn, label = self.imgs[index]#img = Image.open(fn).convert('RGB') img = Image.open(fn)if self.transform is not None:img = self.transform(img) return img, labeldef __len__(self):return len(self.imgs)

getitem 是核心函数。self.imgs 是一个 list,self.imgs[index] 是一个 str,包含图片路径,图片标签,这些信息是从上面生成的txt文件中读取;利用 Image.open 对图片进行读取,注意这里的 img 是单通道还是三通道的;self.transform(img) 对图片进行处理,这个 transform 里边可以实现减均值、除标准差、随机裁剪、旋转、翻转、放射变换等操作。

当 Mydataset构 建好,剩下的操作就交给 DataLoder,在 DataLoder 中,会触发 Mydataset 中的 getiterm 函数读取一张图片的数据和标签,并拼接成一个 batch 返回,作为模型真正的输入。

pipline_train = transforms.Compose([#随机旋转图片transforms.RandomHorizontalFlip(),#将图片尺寸resize到32x32transforms.Resize((32,32)),#将图片转化为Tensor格式transforms.ToTensor(),#正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.1307,),(0.3081,))
])
pipline_test = transforms.Compose([#将图片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))
])
train_data = MyDataset('./data/LEDNUM/train.txt', transform=pipline_train)
test_data = MyDataset('./data/LEDNUM/test.txt', transform=pipline_test)#train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=8, shuffle=True)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=4, shuffle=False)

3.2 搭建 LeNet-5 神经网络结构

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5) self.relu = nn.ReLU()self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.maxpool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = x.view(-1, 16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)output = F.log_softmax(x, dim=1)return output

3.3 将定义好的网络结构搭载到 GPU/CPU,并定义优化器

#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)
#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

3.4 定义训练函数

def train_runner(model, device, trainloader, optimizer, epoch):#训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct =0.0#enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存训练结果outputs = model(inputs)#计算损失和#多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#获取最大概率的预测结果#dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向传播loss.backward()#更新参数optimizer.step()if i % 100 == 0:#loss.item()表示当前loss的数值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total

3.5 定义测试函数

def test_runner(model, device, testloader):#模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值#因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为Falsemodel.eval()#统计模型正确率, 设置初始值correct = 0.0test_loss = 0.0total = 0#torch.no_grad将不会计算梯度, 也不会进行反向传播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label).item()predict = output.argmax(dim=1)#计算正确数量total += label.size(0)correct += (predict == label).sum().item()#计算损失值print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

3.6 运行

#调用
epoch = 5
Loss = []
Accuracy = []
for epoch in range(1, epoch+1):print("start_time",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))loss, acc = train_runner(model, device, trainloader, optimizer, epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model, device, testloader)print("end_time: ",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'\n')print('Finished Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()

经历 5 次 epoch 的 loss 和 accuracy 曲线如下:

3.7 模型保存

torch.save(model, './models/model-mine.pth') #保存模型

3.8 模型测试

下面使用上面训练的模型对一张 LED 图片进行测试。

from PIL import Image
import numpy as npif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('./models/model-mine.pth') #加载模型model = model.to(device)model.eval()    #把模型转为test模式#读取要预测的图片# 读取要预测的图片img = Image.open("./images/test_led.jpg") # 读取图像#img.show()plt.imshow(img,cmap="gray") # 显示图片plt.axis('off') # 不显示坐标轴plt.show()# 导入图片,图片扩展后为[1,1,32,32]trans = transforms.Compose([#将图片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])img = trans(img)img = img.to(device)img = img.unsqueeze(0)  #图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 预测 output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print("概率:",prob)value, predicted = torch.max(output.data, 1)predict = output.argmax(dim=1)print("预测类别:",predict.item())

概率:tensor([[7.2506e-11, 7.0065e-18, 7.1749e-06, 7.4855e-13, 7.3532e-08, 8.5405e-17,2.5753e-15, 9.7887e-10, 2.7855e-05, 9.9996e-01]],grad_fn=<SoftmaxBackward>)
预测类别:9

模型预测结果正确!

以上就是 PyTorch 构建 LeNet-5 卷积神经网络并用它来识别自定义数据集的例子。全文的代码都是可以顺利运行的,建议大家自己跑一边。

总结:

是我们目前分别复现了 LeNet-5 来识别 MNIST、CIFAR10 和自定义数据集,基本上涵盖了基于 PyToch 的 LeNet-5 实战的所有内容。希望对大家有所帮助!

所有完整的代码我都放在 GitHub 上,GitHub地址为:

https://github.com/RedstoneWill/ObjectDetectionLearner/tree/main/LeNet-5

也可以点击阅读原文进入~


推荐阅读

(点击标题可跳转阅读)

干货 | 公众号历史文章精选

我的深度学习入门路线

我的机器学习入门路线图

重磅

AI有道年度技术文章电子版PDF来啦!

扫描下方二维码,添加 AI有道小助手微信,可申请入群,并获得2020完整技术文章合集PDF(一定要备注:入群 + 地点 + 学校/公司。例如:入群+上海+复旦

长按扫码,申请入群

(添加人数较多,请耐心等待)

感谢你的分享,点赞,在看三  

我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!相关推荐

  1. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch 复现了 LeNet-5 ...

  2. Pytorch基础操作 —— 6. 如何使用自定义数据集

    文章目录 自定义数据集 Step 1. 熟悉你的数据集 有数据就要有标签 数据大小.维度一定要一样 归一化 Step 2. 确定如何加载你的数据集 使用 DataLoader 批量加载数据 需要注意的 ...

  3. 从零学PyTorch:DataLoader构建高效的自定义数据集

    Torch中可以创建一个DataSet对象,并与dataloader一起使用,在训练模型时不断为模型提供数据Torch中DataLoader的参数如下 DataLoader(dataset, batc ...

  4. 我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!

    大家好,我是红色石头! 在上两篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 详细介绍了卷积神经网络 ...

  5. Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet、AlexNet、VGG、NIN、GoogleNet、ResNet)——从代码认知CNN经典架构

    Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet.AlexNet.VGG.NIN.GoogleNet.ResNet)--从代码认知CNN经典架构 目录 CNN经典算 ...

  6. 快速入门PyTorch(2)--如何构建一个神经网络

    2019 第 43 篇,总第 67 篇文章 本文大约 4600 字,阅读大约需要 10 分钟 快速入门 PyTorch 教程第二篇,这篇介绍如何构建一个神经网络.上一篇文章: 快速入门Pytorch( ...

  7. PyTorch如何构建和实验神经网络

    点击上方"视学算法",马上关注 真爱,请设置"星标"或点个"在看" 作者 | Tirthajyoti Sarkar 来源 | Medium ...

  8. NVIDIA新作解读:用GAN生成前所未有的高清图像(附PyTorch复现) | PaperDaily #15

    在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点,时刻紧跟 AI 前沿成果. 点击本文底部的「阅读原文」即刻加入社区 ...

  9. LeNet卷积神经网络

    LeNet卷积神经网络 1.介绍 LeNet分为卷积层块和全连接层块连个部分. 卷积层用来识别图像图像里的空间模式,如线条和物体局部,之后值的最大池化层则用来降低卷积层对位置的敏感性. 卷积层块的输出 ...

最新文章

  1. SAP WM 确认TO单据的时候修改目的地Storage Bin
  2. String.Format使用方法
  3. Kotlin中?和!!的区别
  4. 搜索linux中大于m文件,linux 下查找大于100M的文件(转)
  5. pcie usb3.0 驱动 for linux_微软WSL——Linux桌面版未来之光
  6. exif.js html图片旋转,解决图片显示 Exif.js更改图片的显示方向
  7. dotnet-cli命令小结
  8. windows Server 2003使用ip安全策略禁止某ip访问服务器的方法
  9. 利用jsonp实现跨域请求
  10. jquery href属性和click事件冲突
  11. [转载] Python和java中的垃圾回收机制
  12. shiro-cas------实现单点登出并自定义登出starter
  13. mysql5.7 systemctl启动_CentOS 7上配置MySQL5.7开机自启动方法
  14. 最新python中一升级所有已安装的包方法
  15. iPad2如何从iOS6降级到5.1.1
  16. html表格行的悬停事件,jQuery实现HTML表格隔行变色及鼠标悬停变色效果
  17. 04【前端工程化初探】Jenkines+GitLab+Tomcat流水线配置部署React应用
  18. finecms aip.php漏洞,FineCMS最新getshell漏洞通杀FineCMS5.0.8一下版本 | CN-SEC 中文网
  19. 网易微专业python全栈工程师_Python 的工作已经饱和?那是因为你只会 Python
  20. ADS1278学习总结

热门文章

  1. 利用:header匹配所有标题做目录
  2. window.location.href的target控制
  3. Backbone React Requirejs 应用实战(一)——RequireJS管理React依赖
  4. [Ubuntu] 解决 pip 安装 lxml 出现 x86_64-linux-gnu-gcc 异常
  5. BZOJ5102 POI2018Prawnicy(堆)
  6. window下安装nvm、node.js、npm的步骤
  7. 【转】使用C#发送Http 请求实现模拟登陆(以博客园为例)
  8. PHPExel导出报表--导出类
  9. INFO:InstallShield InstallScript工程中自定义界面文本输入控件的两个注意事项
  10. MYSQL-Can't connect to MySQL server on 'localhost' (10061)