最近在学习pytorch,使用mnist数据集,搭建AlexNet训练并保存模型,将代码做一记录。

建立数据集的方法见pytorch建立自己的数据集(以mnist为例)

搭建网络的方法见用pytorch搭建AlexNet(微调预训练模型及手动搭建)

训练代码如下:

import torch
import os
from torchvision import transforms
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import DataProcessing as DP
import BuildModel as BM
import torch.nn as nnif __name__ == '__main__':os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'root_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/data/'training_path = 'trainingset/'test_path = 'testset/'model_path = '/opt/Data/lixiang/ex./pytorch/Alexnet/model/'training_imgfile = training_path + 'trainingset_img.txt'training_labelfile = training_path + 'trainingset_label.txt'training_imgdata = training_path + 'img/'test_imgfile = test_path + 'testset_img.txt'test_labelfile = test_path + 'testset_label.txt'test_imgdata = test_path + 'img/'#parameterbatch_size = 128epochs = 20model_type = 'pre'nclasses = 10lr = 0.01use_gpu = torch.cuda.is_available()transformations = transforms.Compose([transforms.Scale(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])dataset_train = DP.DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata, transformations)dataset_test = DP.DataProcessingMnist(root_path, test_imgfile, test_labelfile, test_imgdata, transformations)num_train, num_test = len(dataset_train), len(dataset_test)train_loader = DataLoader(dataset_train, batch_size = batch_size, shuffle = True, num_workers = 0)test_loader = DataLoader(dataset_test, batch_size = batch_size, shuffle = False, num_workers = 0)# build modelmodel = BM.BuildAlexNet(model_type, nclasses)optimizer = optim.SGD(model.parameters(), lr = lr)criterion = nn.CrossEntropyLoss()for epoch in range(epochs):epoch_loss = 0correct_num = 0for i, traindata in enumerate(train_loader):x_train, y_train = traindataif use_gpu:x_train, y_train = Variable(x_train.cuda()),Variable(y_train.cuda())model = model.cuda()else:x_train, y_train = Variable(x_train),Variable(y_train)y_pre = model(x_train)_, label_pre = torch.max(y_pre.data, 1)if use_gpu:y_pre = y_pre.cuda()label_pre = label_pre.cuda()model.zero_grad()loss = criterion(y_pre, y_train)loss.backward()optimizer.step()epoch_loss += loss.data[0]correct_num += torch.sum(label_pre == y_train.data)        acc = (torch.sum(label_pre == y_train.data).float()/len(y_train))  print('batch loss: {} batch acc: {}'.format(loss.data[0],acc.data[0]))print('epoch: {} training loss: {}, training acc: {}'.format(epoch, epoch_loss, correct_num.float()/num_train))if (epoch+1) % 5 ==0:test_loss = 0test_acc_num = 0for j, testdata in enumerate(test_loader):x_test, y_test = testdataif use_gpu:x_test, y_test = Variable(x_test.cuda()), Variable(y_test.cuda())else:x_test, y_test = Variable(x_test), Variable(y_test)y_pre = model(x_test)_, label_pre = torch.max(y_pre.data, 1)loss = criterion(y_pre, y_test)test_loss += loss.data[0]test_acc_num += torch.sum(label_pre == y_test.data)print('epoch: {} test loss: {} test acc: {}'.format(epoch, test_loss, test_acc_num.float()/num_test))torch.save(model.state_dict(), model_path + 'AlexNet_params.pkl')

主要注意的是一些数据类型的问题,比如label的类型要是LongTensor,损失函数nn.CrossEntropyLoss() 的输入target要是类别编号而不是one-hot编码,使用gpu时要把model和输出y_pre,label_pre移动到gpu上。

pytorch下搭建网络训练并保存模型相关推荐

  1. 神经网络学习小记录14——slim常用函数与如何训练、保存模型

    神经网络学习小记录14--slim训练与保存模型 学习前言 slim是什么 slim常用函数 1.slim = tf.contrib.slim 2.slim.create_global_step 3. ...

  2. 【直播】陈安东,但扬:CNN模型搭建、训练以及LSTM模型思路详解

    CNN模型搭建.训练以及LSTM模型思路详解 目前 Datawhale第24期组队学习 正在如火如荼的进行中.为了大家更好的学习"零基础入门语音识别(食物声音识别)"的课程设计者 ...

  3. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  4. Pytorch快速搭建并训练CNN模型?

    图像来自:快速上手笔记,PyTorch模型训练实用教程(附代码) - 知乎 目录 1.数据处理模块搭建 2.模型构建 3.开始训练 4.评估模型 5.使用模型进行预测 6.保存模型 1.数据处理模块搭 ...

  5. PyTorch下的网络可视化方式和工具

    直接输出网络结果(以文本形式) 以以下博客为例:4.Deep Residual Network_马鹏森的博客-CSDN博客 The simplest way to visualize is to pr ...

  6. PyTorch中CNN网络参数计算和模型文件大小预估

    前言 在深度学习CNN构建过程中,网络的参数量是一个需要考虑的问题.太深的网络或是太大的卷积核.太多的特征图通道数都会导致网络参数量上升.写出的模型文件也会很大.所以提前计算网络参数和预估模型文件大小 ...

  7. 图像识别python cnn_MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 全连接神经网络是深度学习的基础,理解它就可以掌握深度学习的核心概念:前向传播.反向误差传递.权重.学习 ...

  8. pytorch: 在训练中保存模型,加载模型

    文章目录 1. 保存整个模型 2.仅保存和加载模型参数(推荐使用) 3. 保存其他参数到模型中,比如optimizer,epoch 等 1. 保存整个模型 torch.save(model, 'mod ...

  9. Pytorch搭建网络训练葡萄酒分类数据集(三分类)

    代码如下: import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F ...

最新文章

  1. 程序员开发进度太慢被告上法庭!公司索赔90万,拿出百度词条当证据
  2. 关于USART接收中断的BUG和注意事项
  3. Android7.0多窗口实现原理(一)
  4. 6本Android开发必备图书
  5. 一张图看透办公网安全
  6. iris数据集——决策树
  7. python生成倒计时图片_用Python自动化生成新年倒计时图片
  8. html控制word打印在一张页面,HTML文件到WORD文档双面打印三步曲
  9. 状态良好(恢复分区)空间的删除的方法
  10. linux 怎么彻底删除用户,linux如何完全删除用户
  11. A FastDetectionMethodviaRegion-BasedFullyConvolutionalNeuralNetworksforShieldTunnelLiningDefects-笔记
  12. supervisor web页面访问
  13. 计算机绘制表格教案,电脑制作表格教案设计
  14. 聊一聊推荐系统中ExploitExplore算法
  15. 中国(吉林)首批援萨摩亚医疗队凯旋
  16. Springcloud微服务概述
  17. 如何在https协议下访问http等不安全的资源
  18. 诺基亚Lumia920竞争力分析——对比三星Ativ S、Galaxy S3、HTC One X、Iphone5
  19. JavaEE 面试题总结
  20. 「新世相」都写过什么题材?如何通过数据挖掘写作题材

热门文章

  1. 课程格子是这样做大学生市场的
  2. 关于SIGTERM信号
  3. progress java驱动_JAVA连接Progress数据库
  4. 电脑如何连接打印机以及共享打印
  5. 一种更为高效的WAL的实现方式
  6. python转义html字符串,用python处理html代码的转义与复原
  7. Web3j如何在ETH智能合约调用请求发出前获取到转账Hash
  8. STM32的命名含义
  9. 2021-01-28 PMP 群内练习题 - 光环
  10. 极验验证码 Geetest