代码:

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import argparse
import os# 训练
def train(args, model, device, train_loader, optimizer):for epoch in range(1, args.epochs + 1):model.train()for batch_index, data in enumerate(train_loader):images, labels = dataimages = images.to(device)labels = labels.to(device)# forwardoutput = model(images)loss = F.cross_entropy(output, labels)# backwardoptimizer.zero_grad()  # 梯度清空loss.backward()  # 梯度回传,更新参数optimizer.step()# 打印lossprint(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}')# 保存模型if epoch % args.checkpoint_interval == 0:torch.save(model.state_dict(), f'checkpoints/cifar10_%d.pth' % epoch)def test(args, model, device, test_loader):model.eval()total_loss = 0num_correect = 0with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)# 总的losstotal_loss += F.cross_entropy(outputs, labels).item()# 预测值_, predected = torch.max(outputs, dim=1)# 预测对的总个数num_correect += (predected==labels).sum().item()# 计算平均lossaverage_loss = total_loss / len(test_loader.dataset)# 计算准确率accuracy = num_correect / len(test_loader.dataset)# 打印平均loss和准确率print(f'Average loss:{average_loss}\nTest Accuracy:{accuracy*100}%')if __name__ == '__main__':parser = argparse.ArgumentParser(description = 'Pytorch-cifar10_classification')parser.add_argument('--epochs', type=int, default=10, help='number of epochs')parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')parser.add_argument('--num_classes', type=int, default=10, help='number of classes')parser.add_argument('--lr', type=float, default=0.001, help='learning rate')parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')parser.add_argument('--pretrained_weights', type=str, default='checkpoints/cifar10_17.pth',help='if specified starts from checkpoint model')parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")parser.add_argument("--train", default=True, help="train or test")args = parser.parse_args()print(args)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# os.makedirs() 方法用于递归创建目录os.makedirs("output", exist_ok=True)os.makedirs("checkpoints", exist_ok=True)# transformdata_transform = transforms.Compose([transforms.ToTensor(),transforms.RandomResizedCrop(args.img_size)])# 下载训练数据集trian_data = datasets.CIFAR10(root = 'data',train = True,download = False,transform = data_transform,target_transform = None,)# 下载测试数据集test_data = datasets.CIFAR10(root = "data",train = False,download = False,transform = data_transform,target_transform = None)# 加载数据train_loader = DataLoader(dataset = trian_data,batch_size = args.batch_size,shuffle = True)test_loader = DataLoader(dataset = test_data,batch_size = args.batch_size)# 创建模型,使用预训练好的权重model = models.vgg16(pretrained = True)# # 冻结模型,参数不更新# for para in model.parameters():#     para.requires_grad = False# # 只训练全连接层# model.classifier[3].requires_grad = True# model.classifier[6].requires_grad = True# 修改vgg16的输出维度model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)model = model.to(device)# 打印网络结构print(model)# 定义优化器(也可以选择其他优化器)optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)# optimizer = torch.optim.Adam(model.parameters())if train == True:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))for epoch in range(1, epochs+1):train(args, model, device, train_loader, optimizer)else:if args.pretrained_weights.endswith(".pth"):model.load_state_dict(torch.load(args.pretrained_weights))test(args, model, device, test_loader)

说明:
        cifar10数据集可以通过trochvision中的datasets.CIFAR10下载,也可以自己下载(注意存放路径);我模型使用的是torchvision中的models中预训练好的vgg16网络,也可以自己搭建网络。

【Pytorch】CIFAR1010数据集的训练和测试相关推荐

  1. FCN制作自己的数据集、训练和测试 caffe

    原文:http://blog.csdn.net/zoro_lov3/article/details/74550735 FCN制作自己的数据集.训练和测试全流程 花了两三周的时间,在导师的催促下,把FC ...

  2. python划分数据集用pandas_用pandas划分数据集实现训练集和测试集

    1.使用model_select子模块中的train_test_split函数进行划分 数据:使用kaggle上Titanic数据集 划分方法:随机划分 # 导入pandas模块,sklearn中mo ...

  3. [Python+sklearn] 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split()

    Python - sklearn 拆分数据集为训练和测试子集 sklearn.model_selection.train_test_split() 功能: 将数组或矩阵拆分为随机的训练子集和测试子集 ...

  4. 将数据集分为训练集和测试集(python脚本)

    文章目录 程序: 下面简单介绍一下程序流程 1.引入库 os库 shutil random 2.mk_file函数 3.主函数 程序: 我们在训练卷积神经网络之前,要搭建好数据集,分成训练集和测试集两 ...

  5. 【Pytorch】MNIST数据集的训练和测试

    训练和测试的完整代码: import torch import torch.nn as nn import torch.nn.functional as F from torchvision impo ...

  6. 自定义ava数据集及训练与测试 完整版 时空动作/行为 视频数据集制作 yolov5, deep sort, VIA MMAction, SlowFast

    前言 这一篇博客应该是我花时间最多的一次了,从2022年1月底至2022年4月底. 我已经将这篇博客的内容写为论文,上传至arxiv:https://arxiv.org/pdf/2204.10160. ...

  7. 7个Bert变种模型baseline在7个文本分类数据集上训练和测试

    引入和代码项目简介 https://github.com/songyingxin/Bert-TextClassification 模型有哪些? 使用的模型有下面七个 BertOrigin, BertC ...

  8. 【caffe】mnist数据集lenet训练与测试

    在上一篇中,费了九牛二虎之力总算是把Caffe编译通过了,现在我们可以借助mnist数据集,测试下Caffe的训练和检测效果. 准备工作:在自己的工作目录下,新建一个文件夹,命名为mnist_test ...

  9. 机器学习之数据集划分——训练集测试集划分,划分函数,估计器的使用

    训练集测试集划分,划分函数,估计器的使用 参考文章 训练集.验证集和测试集的划分及交叉验证的讲解 划分训练集和测试集的函数学习 sklearn数据集,数据集划分,估计器详细讲解 参考文章 训练集.验证 ...

最新文章

  1. 特斯拉遇上 CPU:程序员的心思你别猜
  2. 二级C语言学习宝典下载,二级C语言学习宝典
  3. [武道资料]《菲律宾短棍-单棍》(Edgar Sulite Lameco Escrima Single Stick)
  4. 详细盘点joomla1.5和2.5中那些常用的扩展
  5. leetcode 292. Nim Game | 292. Nim 游戏(DP->数学推理)
  6. lame,把ios录音转换为mp3格式
  7. 使用Drools 6.0进行部署
  8. ajax前台multipartfile,在SpringBoot中使用Ajax方式MultipartFile上传失败
  9. nhibernate学习之集合组合依赖
  10. Win7旗舰版系统网页显示不全怎么办
  11. 【项目管理】敏捷和Scrum
  12. 五、Hashtable与HashMap的区别
  13. css3 media媒体查询器用法总结
  14. 多线程同步工具——volatile变量
  15. 图形界面编程成就了C++
  16. 专题八图形窗口与坐标轴
  17. 使用网络爬虫为英语单词添加音标
  18. android百度地图api两点画线,android百度地图:在地图上绘制点、线、多边形、圆形和文字...
  19. 摩拜单车的“黑科技”
  20. 反直觉的三门问题,80%的人都会错?

热门文章

  1. iconv 判断字符编码_php下用iconv函数转换字符编码的问题
  2. linux中流设备_Linux中的标准文件I/O流
  3. 【C语言】又是排序(指针专题)
  4. python算法应用(六)——搜索与排名2(PageRank算法及其拓展应用)
  5. 6410的系统时钟设置(上)---6410时钟控制逻辑框架分析
  6. pthread_mutex_init函数《代码》
  7. pthread条件变量函数的使用
  8. 【STM32】STM32F4 CAN2只能发送无法接收问题解决
  9. QT5开发及实例学习之十四Qt5排版功能
  10. python输出布尔值true_关于python中bool类型的重要细节