训练和测试的完整代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import argparse
import os# 训练
def train(args, model, device, train_loader, optimizer):model.train()num_correct = 0for batch_index, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# forwardoutputs = model(images)loss = F.cross_entropy(outputs, labels)# backwardoptimizer.zero_grad()   # 梯度清空loss.backward()            # 梯度回传,更新参数optimizer.step()_, predicted = torch.max(outputs, dim=1)# 每一个batch预测对的个数batch_correct = (predicted == labels).sum().item()# 每一个batch的准确率batch_accuracy = batch_correct / args.batch_size# 每一个epoch预测对的总个数num_correct += (predicted == labels).sum().item()# print sth.print(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}, Batch accuracy:{batch_accuracy*100}%')# 每一个epoch的准确率epoch_accuracy = num_correct / len(train_loader.dataset)# print epoch_accuracyprint(f'Epoch Accuracy:{epoch_accuracy}')# 保存模型if epoch % args.checkpoint_interval == 0:torch.save(model.state_dict(), f"checkpoints/VGG16_MNIST_%d.pth" % epoch)# 验证
def test(args, model, device, test_loader):model.eval()total_loss = 0num_correct = 0if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))# 不计算梯度,节省计算资源with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)output = model(images)# 总的losstotal_loss += F.cross_entropy(output, labels).item()     # item()用于取出tensor里边的值# torch.max():返回的是两个值,第一个值是具体的value,第二个值是value所在的index_, predicted = torch.max(output, dim=1)# 预测对的总个数num_correct += (predicted == labels).sum().item()# 平均losstest_loss = total_loss / len(test_loader.dataset)# 平均准确率accuracy = num_correct / len(test_loader.dataset)# print sth.print(f'Average loss:{test_loss}\nTest Accuracy:{accuracy*100}%')if __name__ == '__main__':parser = argparse.ArgumentParser(description = 'Pytorch-MNIST_classification')parser.add_argument('--epochs', type=int, default=20, help='number of epochs')parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch' )parser.add_argument('--num_classes', type=int, default=10, help='number of classes')parser.add_argument('--lr', type=float, default=0.001, help='learning rate')parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')parser.add_argument('--pretrained_weights', type=str, default='checkpoints/', help='if specified starts from checkpoint model')parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")parser.add_argument("--train",  default=False, help="train or test")args = parser.parse_args()print(args)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# os.makedirs() 方法用于递归创建目录os.makedirs("output", exist_ok=True)os.makedirs("checkpoints", exist_ok=True)# transformdata_transform = transforms.Compose([transforms.ToTensor(), transforms.RandomResizedCrop(args.img_size)])# 下载训练数据train_data = datasets.MNIST(root = 'data',train = True,transform = data_transform,target_transform = None,download = True)# 下载测试数据test_data = datasets.MNIST(root = 'data',train = False,transform = data_transform,target_transform = None,download = True)# 加载训练数据train_loader = DataLoader(dataset = train_data,batch_size = args.batch_size,shuffle = True)# 加载测试数据test_loader = DataLoader(dataset = test_data,batch_size = args.batch_size)# 创建模型model = models.vgg16(pretrained = True)# 修改vgg16的输出维度model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)# MNIST数据集是灰度图,channel数为1model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))print(model)model = model.to(device)# 优化器(也可以选择其他优化器)optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)# optimizer = torch.optim.Adam()if args.train == True:for epoch in range(1, args.epochs+1):# 是否加载预训练好的权重if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))train(args, model, device, train_loader, optimizer)else:# 是否加载预训练好的权重if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))test(args, model, device, test_loader)

测试结果:
        我只训练了不到10轮,效果不是太好,还有提升空间。

说明:
        MNIST数据集可以通过trochvision中的datasets.MNIST下载,也可以自己下载(注意存放路径);我模型使用的是torchvision中的models中预训练好的vgg16网络,也可以自己搭建网络。

【Pytorch】MNIST数据集的训练和测试相关推荐

  1. 【caffe】mnist数据集lenet训练与测试

    在上一篇中,费了九牛二虎之力总算是把Caffe编译通过了,现在我们可以借助mnist数据集,测试下Caffe的训练和检测效果. 准备工作:在自己的工作目录下,新建一个文件夹,命名为mnist_test ...

  2. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  3. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  4. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  5. FCN制作自己的数据集、训练和测试 caffe

    原文:http://blog.csdn.net/zoro_lov3/article/details/74550735 FCN制作自己的数据集.训练和测试全流程 花了两三周的时间,在导师的催促下,把FC ...

  6. DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练.预测 导读           计算图在神经网络算法中的作用.计算图的节点是由局部计算构成的. ...

  7. DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、GC对比

    DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练.GC对比 导读           神经网络算法封装为层级结构的作用.在神经网络算法中,通过将 ...

  8. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  9. python划分数据集用pandas_用pandas划分数据集实现训练集和测试集

    1.使用model_select子模块中的train_test_split函数进行划分 数据:使用kaggle上Titanic数据集 划分方法:随机划分 # 导入pandas模块,sklearn中mo ...

最新文章

  1. 【Flutter】开发 Flutter 包和插件 ( Flutter 包和插件简介 | 创建 Flutter 插件 | 创建 Dart 包 )
  2. ATM(BZOJ 1179)
  3. 《统一沟通-微软-实战》-6-部署-2-中介服务器-6-语音路由-路由
  4. .NET开发不可错过的25款必备工具
  5. CRM Fiori:Complex note optimization design
  6. nor flash和nand flash
  7. 修改html本地样式,html-如何通过Javascript更改CSS类样式?
  8. 彼得.泰尔:认知未来是投资人的谋生之道
  9. (转)证券投资及财富管理市场创新趋势
  10. Windows使用ffmpeg教程
  11. 路由器:斐讯K3C刷官改,固件版本是32.1.46.268
  12. VC++调用大漠插件
  13. 62%中国AI毕业生赴美,机器学习人才最高产大学出炉 | 报告
  14. C++调用Lua出现 unproteted error in call to Lua API错误的发现过程与解决方法
  15. 分组和聚合函数的组合使用实操
  16. 原装苹果手机_苹果手机换个屏水这么深!嘉兴警方揭开“原装屏”真相
  17. Abaqus CPU并行计算 加速计算信息汇总
  18. vivo X90和iPhone 14哪个好 vivo X90 和苹果14 区别对比评测
  19. 路由器网口1一直闪烁正常吗_网口1一直闪烁上不了网
  20. 小学生体测测试环境怎么填_2016年国家学生体质健康测试数据上报工作说明

热门文章

  1. 嵌入式 U 盘自动挂载
  2. 12016.xilinx裸机开发
  3. 东财计算机应用基础单元作业,东财21春《计算机应用基础》单元作业一 【标准答案】...
  4. python写入mysql乱码_python MYsql中文乱码
  5. 【绪论】——声呐概述
  6. Java学习日报—JVM垃圾回收全解—2021/11/26
  7. 我的世界java村民繁殖_我的世界:1.14版本刷新几率小的五种村庄,没有村民咋回事?...
  8. java班长竞选投票_竞选班长采取投票式,引家长不满,班主任:您说该怎么选?...
  9. 【Spring】事务
  10. Java面试23种设计模式之单例模式的8种实现方式