和上相比。。下使用了board可视化训练过程,训练结束后在log文件下面生成日志
在终端输入命令

tensorboard --logdir ./

打开

# -*- encoding: utf-8 -*-
"""
@File    : train.py
@Time    : 2021-03-07 16:24
@Author  : XD
@Email   : gudianpai@qq.com
@Software: PyCharm
"""
import osimport torch
import torch.nn as nn
import torchvision
import tensorboardXfrom vggnet import VGGNetfrom load_cifar10 import train_loader
from load_cifar10 import test_loader#判断是否有gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#遍历200次
epoch_num = 2#学习率
lr = 0.01#batch_size = 128
batch_size = 128net = VGGNet().to(device)#loss多分类问题,交叉熵来定义
loss_func = nn.CrossEntropyLoss()#定义优化器
optimizer = torch.optim.Adam(net.parameters(),lr = lr)
#optimizer = torch.optim.SGD(net.parameters(),lr = lr,
#                            monmentum = 0.9,weight_decat = 5e-4)#质数衰减学习率
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size = 1,gamma = 0.9)if not os.path.exists("log"):os.mkdir("log")
writer = tensorboardX.SummaryWriter("log")step_n = 0for epoch in range(epoch_num):print(" epoch is: ", epoch)net.train() #train BN dropoutfor i, data in enumerate(train_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = loss_func(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print("step:",i,"loss is:",loss.item())_,pred = torch.max(outputs.data,dim = 1)correct = pred.eq(labels.data).cpu().sum()# print(" epoch is: ", epoch)# print("step:",i,"loss is:",loss.item(),#       "mini-batch correct is:",100.0 * correct / batch_size)# print("lr is:", optimizer.state_dict()["param_groups"][0]["lr"])#x = torch.tensor([1.0])#x.item()# 1.0writer.add_scalar("train loss:",loss.item(),global_step = step_n)writer.add_scalar("train correct",100.0 * correct.item(),global_step = step_n)im = torchvision.utils.make_grid(inputs)writer.add_image("train im",im,global_step = step_n)step_n += 1if not os.path.exists("models"):os.mkdir("models")torch.save(net.state_dict(),"models\{}.pth".format(epoch + 1))scheduler.step()#编写一个测试脚本sum_loss = 0sum_correct = 0for i, data in enumerate(test_loader):net.eval()inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = loss_func(outputs, labels)_, pred = torch.max(outputs.data, dim=1)correct = pred.eq(labels.data).cpu().sum()sum_loss += loss.item()sum_correct += correct.item()writer.add_scalar("test loss:", loss.item(),global_step = step_n)writer.add_scalar("test correct:",100.0 * correct.item() / batch_size,global_step = step_n)writer.add_image("test im",im,global_step = step_n)test_loss = sum_loss * 1.0 / len(test_loader)test_correct = sum_correct * 100.0 / len(test_loader) / batch_sizeprint("epoch is:", epoch + 1, "loss is:", test_loss,"test correct is:", test_correct)writer.close()

6-7Pytorch搭建cifar10训练脚本(下)相关推荐

  1. 6-7Pytorch搭建cifar10训练脚本(上)

    需要详解一下代码~~ import torch.nn.functional as F,包含 torch.nn 库中所有函数,同时包含大量 loss 和 activation function # -* ...

  2. 基于Keras搭建cifar10数据集训练预测Pipeline

    基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...

  3. Centos7下的LibreOffice的搭建及自动化脚本部署

    Centos7下的LibreOffice的搭建及自动化脚本部署 LibreOffice 简介 LibreOffice 是一个强大的办公套件 – 它清晰的界面和强大的工具让您释放您的创造力并增长您的生产 ...

  4. pytorch下搭建网络训练并保存模型

    最近在学习pytorch,使用mnist数据集,搭建AlexNet训练并保存模型,将代码做一记录. 建立数据集的方法见pytorch建立自己的数据集(以mnist为例) 搭建网络的方法见用pytorc ...

  5. 天池训练营——基于人脸的常见表情识别(3)——模型搭建、训练与测试

    在完成数据准备之后,便可以使用 PyTorch 深度学习框架,实现卷积神经网络的定义.训练和预测. 一.模型搭建与训练 得到了数据之后,接下来咱们使用 PyTorch 这个框架来进行模型的训练.整个训 ...

  6. 基于人脸的常见表情识别(3)——模型搭建、训练与测试

    基于人脸的常见表情识别(3)--模型搭建.训练与测试 模型搭建与训练 1. 数据接口准备 2. 模型定义 3. 模型训练 模型测试 本 Task 是『基于人脸的常见表情识别』训练营的第 3 课,如果你 ...

  7. Faster-rcnn环境搭建与训练自己的数据

    Faster-RCNN环境搭建与训练自己的数据 0 前言 之前整理过一篇关于fasterrcnn的文章,文中详细介绍了fasterrcnn原理分析,近期由于工作需要利用fasterrcnn进行模型训练 ...

  8. 环境搭建:Windows系统下Nacos集群搭建

    环境搭建:Windows系统下Nacos集群搭建 一.环境准备 名称 版本 下载地址 nacos NACOS 1.2.0 下载地址,提取码:5555 MySQL mysql Ver 14.14 Dis ...

  9. 第十二章_网络搭建及训练

    文章目录 第十二章 网络搭建及训练 CNN训练注意事项 第十二章 TensorFlow.pytorch和caffe介绍 12.1 TensorFlow 12.1.1 TensorFlow是什么? 12 ...

最新文章

  1. F5负载均衡会话保持技术及原理技术白皮书
  2. 独家 | 带你认识HDFS和如何创建3个节点HDFS集群(附代码案例)
  3. 关于一个枚举IE表单的DLL,编译无错,但是得不到想到的结果。
  4. springboot map数据类型注入_SpringBoot结合策略模式实战套路
  5. 云原生全景图之六 | 托管 Kubernetes 和 PaaS 解决什么问题
  6. 华为双11发 20 亿奖金!?
  7. C#LeetCode刷题,走进Google,走近人生
  8. 【FFMPEG系列】windows下编译ffmpeg且加入libx264
  9. 银行卡号,指定字符长度分割字符串
  10. 圣经 创世纪 1:20-22
  11. java 气象数据_中国天气预报数据API收集
  12. 什么是索引,索引的作用,什么时候需要使用索引,什么时候不需要使用索引
  13. win32gui操作
  14. 1bit quantization
  15. iOS开发 ☞ Commen Sense
  16. Java基础知识Day08---Scaner类
  17. Python shellcode免杀
  18. F-score is ill-defined and being set to 0.0 in labels with no true samples.
  19. 部分国外邮箱服务商简介
  20. SVN提交报错svn: Commit blocked by pre-commit hook (exit code 1) with output: Can't get Mantis_Key的解决办法

热门文章

  1. 【BZOJ3601】一个人的数论,莫比乌斯反演+高斯消元
  2. 【BZOJ3675】序列统计,斜率优化DP
  3. python80行代码写一个文件整理软件
  4. android 自定义button,android – 如何添加自定义按钮状态
  5. java filter 回调_Java 异步回调机制实例分析
  6. CentOS7 安装lua环境
  7. 光线求交加速算法:边界体积层次结构(Bounding Volume Hierarchies)2-表面积启发式法(The Surface Area Heuristic)
  8. python空集合_python空集合
  9. python嵌入shell代码_小白进!嵌入式开发如何快速入门?
  10. 机器学习分类_机器学习之简单分类模型