目录

  • 训练
    • 训练代码
    • nn.NLLLoss()与nn.CrossEntropyLoss()的区别
  • 测试
    • 混淆矩阵(Confusion Matrix)
    • 验证代码

github地址:
https://github.com/Huyf9/mnist_pytorch/

训练

在训练之前,需要定义以下几个参数

DEVICE  #设备
BATCH_SIZE  #批量数
CRITERION  #损失函数
LR  #学习率
OPTIMIZER  #优化器
EPOCHS  #训练轮数

我的训练参数设置如下:

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 64
CRITERION  = nn.NLLLoss()
LR = 0.001
OPTIMIZER  = torch.optim.SGD(net.parameters(), lr=LR)
EPOCHS = 200

训练代码

以卷积网络模型为例,训练代码如下:

import torch
import torch.nn as nn
from MnistDataset import Mydataset
from torch.utils.data import DataLoader
from model import ConvNet
from tqdm import tqdmDEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(64)  # 设置一个随机种子,保证每次训练的结果一样train_path = ['train.txt', 'tr_label.txt']
val_path = ['val.txt', 'val_label.txt']
train_dataset = Mydataset(train_path[0], train_path[1], device)
val_dataset = Mydataset(val_path[0], val_path[1], device)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=4)net = ConvNet().to(device)
CRITERION = nn.NLLLoss()
LR = 0.001
OPTIMIZER = torch.optim.SGD(net.parameters(), lr=LR)
EPOCHS = 100def train(epoch, epoch_loss, batch_num):for i, (pic, label) in tqdm(enumerate(train_dataloader)):batch_num += 1net.zero_grad()out = net(pic)# print(out)loss_value = loss(out, label)epoch_loss += loss_valueloss_value.backward()optimizer.step()print(f'epoch: {epoch}\ttrain_loss: {epoch_loss/batch_num}')def val(epoch, epoch_loss, batch_num):for i, (pic, label) in tqdm(enumerate(val_dataloader)):epoch_num += 1out = net(pic)loss_value = loss(out, label)epoch_loss += loss_valueprint(f'epoch: {epoch}\tval_loss: {epoch_loss/batch_num}')def main():for epoch in range(EPOCHS):epoch_loss, batch_num = 0, 0train(epoch, epoch_loss, batch_num)val(epoch, epoch_loss, batch_num)if (epoch+1) % 10 == 0:torch.save(net.state_dict(), f'model_parameter\\parameter_epo{epoch}.pth')if __name__ == '__main__':main()

这里我利用epoch_loss来累加每一个批次的损失函数,再用batch_num来记录每一轮的批次数,最后相除作为这一轮的平均损失。

nn.NLLLoss()与nn.CrossEntropyLoss()的区别

CrossEntropyLoss()=NLLLoss()+LogSoftmax()CrossEntropyLoss() = NLLLoss() + LogSoftmax()CrossEntropyLoss()=NLLLoss()+LogSoftmax()
由于我们在构建网络的时候在最后一层加上了nn.LogSoftmax(),因此在定义损失函数时我们采用NLLLoss()。

测试

验证部分我们利用训练好的训练模型,将图片输入进去,返回一个1行10列的向量表示数字0-9的概率,我们取最大概率的索引表示模型判断的数字。我们将其保存在混淆矩阵中

混淆矩阵(Confusion Matrix)

混淆矩阵表示分类模型的预测值与真实值的对比情况。以二分类的混淆矩阵为例:

Positive Negative
True TP TN
False FP FN

TP表示正确被预测为正例的数量。
TN表示正确被预测为负例的数量。
FP表示错误预测为正例的数量。
FN表示错误被预测为负例的数量。

这里我们引入三个评价指标来度量一个模型的预测能力:查准率(Percision)、查全率(Recall)、F1。
Percision=TP/(TP+FP)Percision = TP / (TP+FP)Percision=TP/(TP+FP)

表示分类模型预测为正例的样本中,真正为正例的样本比重。

Recall=TP/(TP+FN)Recall = TP / (TP+FN)Recall=TP/(TP+FN)

表示分类模型预测为正例的样本占总正例样本的比重

一般来说,Percision与Recall为一对矛盾的指标,我们一般不会指定某一个指标衡量模型性能,因此我们需要协调两种指标的值,这样就需要引入一个评价指标F1:
F1=2∗Percision∗RecallPercision+RecallF1 = 2*{{Percision*Recall}\over{Percision+Recall}}F1=2∗Percision+RecallPercision∗Recall​

验证代码

验证代码如下

import torch
import seaborn as sn
from matplotlib import pyplot as plt
from model import ConvNet
from MnistDataset import Mydataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pdtorch.manual_seed(13)def get_score(confusion_mat):smooth = 0.0001  #防止出现除数为0而加上一个很小的数tp = np.diagonal(confusion_mat)fp = np.sum(confusion_mat, axis=0)fn = np.sum(confusion_mat, axis=1)precision = tp / (fp + smooth)recall = tp / (fn + smooth)f1 = 2 * precision * recall / (precision + recall + smooth)return precision, recall, f1def get_confusion(confusion_matrix, out, label):idx = np.argmax(out.detach().numpy())confusion_matrix[idx, label] += 1return confusion_matrixdef main():confusion_matrix = np.zeros((10, 10))net = ConvNet()net.load_state_dict(torch.load('model_parameter\\parameter_epo90.pth'))test_path = ['test.txt', r'dataset/test_label.txt']test_dataset = Mydataset(test_path[0], test_path[1], 'cpu')test_dataloader = DataLoader(test_dataset, 1, True)for i, (pic, label) in enumerate(test_dataloader):out = net(pic)confusion_matrix = get_confusion(confusion_matrix, out, label)precision, recall, f1 = get_score(confusion_matrix)print(f'precision: {np.average(precision)}\trecall: {np.average(recall)}\tf1: {np.average(f1)}')confusion_mat = pd.DataFrame(confusion_matrix)confusion_df = pd.DataFrame(confusion_mat, index=[i for i in range(10)], columns=[i for i in range(10)])sn.heatmap(data=confusion_df, cmap='RdBu_r')plt.show()confusion_df.to_csv(r'confusion.csv', encoding='ANSI')if __name__ == '__main__':main()

手写字体识别(3) 训练及测试相关推荐

  1. 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)

    人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...

  2. pytorch rnn 实现手写字体识别

    pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...

  3. 手写字体识别 --MNIST数据集

    Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...

  4. 神经网络学习(二)Tensorflow-简单神经网络(全连接层神经网络)实现手写字体识别

    神经网络学习(二)神经网络-手写字体识别 框架:Tensorflow 1.10.0 数据集:mnist数据集 策略:交叉熵损失 优化:梯度下降 五个模块:拿数据.搭网络.求损失.优化损失.算准确率 一 ...

  5. 神经网络实现手写字体识别

    神经网络入门学习中,进行了手写字体识别实践,该篇博客用于记录实践代码,以备后续使用. 关键词:神经网络,前向传播.反向传播.梯度下降.权值更新.手写字体识别 1. 实践代码 import numpy ...

  6. 计算机视觉ch8 基于LeNet的手写字体识别

    文章目录 原理 LeNet的简单介绍 Minist数据集的特点 Python代码实现 原理 卷积神经网络参考:https://www.cnblogs.com/chensheng-zhou/p/6380 ...

  7. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  8. PyTorch手写字体识别MNIST

    手写字体识别MNIST 1.准备工作 可以看这个老师的视频进行学习,讲解的非常仔细:视频学习 2.项目代码 2.1 导入模块 # 1.加载相关库 import torch import torch.n ...

  9. 基于Python神经网络的手写字体识别

    本文将分享实现手写字体识别的神经网络实现,代码中有详细注释以及我自己的一些体会,希望能帮助到大家 (≧∇≦)/ ############################################ ...

最新文章

  1. 【第23周复盘】懒癌犯了,拖到今天!
  2. MacOS系统下简单安装以及配置MongoDB数据库(一)
  3. 一张图解释什么是遗传算法_一张图告诉你什么叫真正的满配m416,吃鸡玩家看懵了...
  4. mysql为查询结果字段赋默认值
  5. 数据下载工作笔记三:脚本
  6. jzoj3084-超级变变变【数学】
  7. echarts label固定位置_ECharts+百度地图网络拓扑应用
  8. Arduino笔记-外部中断实验(震动传感器实时亮灯)
  9. font awesome java_java awt实现 fontawesome转png
  10. mysql explain desc_MySQL中EXPLAIN命令详解
  11. python中tmp什么意思_python中temp是什么意思-问答-阿里云开发者社区-阿里云
  12. 《ANTLR 4权威指南》——第2章 纵 观 全 局 2.1 从ANTLR元语言开始
  13. html5游戏打包apk,laya打包APK无法进入游戏
  14. 45_局域网ip正则表达式
  15. c语言中取反符号的理解
  16. EXCEL如何将两列数据合并为一列并在中间加符号
  17. 配置管理的目标和主要活动
  18. 分布式数据结构与算法面试题
  19. 常用测试用例设计方法总结
  20. ZOJ3587 Marlon's String

热门文章

  1. [NOIP2001] 统计单词个数
  2. 新一代智能搜索引擎,让搜索一击即中
  3. 区块链电子签名技术及方案
  4. linux能ping通ssh连不上,能ping通Linux但是ssh连不上问题解决方法
  5. 微波技术基础------史密斯原图
  6. 2018最新出现的勒索病毒及变种统计丨阿里云河南
  7. spring-boot-2.0.3启动源码篇一 - SpringApplication构造方法
  8. 学习php技术最快需要多久,学php最快要多久 学习路线
  9. Jetson AGX Xavier部署深度学习环境
  10. 企业级ajax框架,企业级AJAX框架设计与实现.pdf