目录

1、导入必要库

2、加载数据

3、构建网络

4、训练模型

5、保存模型参数

1)、仅仅保存和加载模型参数

2)、保存和加载整个模型

3)、保存多个模型参数


1、导入必要库

import torch
from torch import optim, nn
import torch.utils.data as Data

2、加载数据

x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)# 注意:x的数据类型是 torch.FloatTensor
# y的数据类型是 torch.LongTensor
# x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # FloatTensor = 32-bit floating
# y = torch.cat((y0, y1), ).type(torch.LongTensor)    # LongTensor = 64-bit integer# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)# 把 dataset 放入 DataLoader
loader = Data.DataLoader(dataset=torch_dataset,      # torch TensorDataset formatbatch_size=3,      # mini batch sizeshuffle=True,               # 要不要打乱数据 (打乱比较好)num_workers=0,              # 多线程来读数据
)

3、构建网络

# 定义网络结构 build net
class Net(torch.nn.Module):def __init__(self,n_feature,n_hidden,n_output):super(Net, self).__init__()self.fc1 =torch.nn.Linear(n_feature,n_hidden)self.fc2 =torch.nn.Linear(n_hidden,n_output)# 定义一个前向传播过程函数def forward(self, x):x=F.relu(self.fc1(x))x=self.fc2(x)return x
# 实例化一个网络为 model
model = Net(n_feature=1,n_hidden=10,n_output=10)
print(model)

4、训练模型

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_func = nn.CrossEntropyLoss() # 训练模型
model.train()
for epoch in range(5):for step, (b_x, b_y) in enumerate(loader): output = model(b_x)loss = loss_func(output, b_y)optimizer.zero_grad()loss.backward()optimizer.step()# 测试模型
model.eval()
for step, (b_x, b_y) in enumerate(loader):output = model(b_x)loss = loss_func(output, b_y)_, pred_y = torch.max(output.data, 1)correct = (pred_y == b_y).sum()total = b_y.size(0)print('Epoch: ', step, '| test loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % (float(correct)/total))

5、保存模型参数

1)、仅仅保存和加载模型参数

# 保存模型参数
torch.save(model.state_dict(), './path/model.pkl')
# 读取模型参数
model.load_state_dict(torch.load('./path/model.pkl'))

2)、保存和加载整个模型

# 保存整个模型
torch.save(model,  './path/model.pkl')
# 加载整个模型
model = torch.load('./path/model.pkl')

3)、保存多个模型参数

# 多个模型参数保存
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# 模型参数加载
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

Pytorch 模型训练步骤相关推荐

  1. 《PyTorch模型训练实用教程》—学习笔记

    文章目录 前言 数据 Dataset类 DataLoader类 transform 裁剪-Crop 翻转和旋转-Flip and Rotation 图像变换 对transforms操作,使数据增强更灵 ...

  2. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  3. PyTorch 模型训练实用教程(附代码)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用 Py ...

  4. Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整

    前言 最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读. 于是在gayhub上找到了这样一份教程<Pytorch模型训练实用教程>,写 ...

  5. 9个让PyTorch模型训练提速的技巧!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 来源:AI公园,译者:ronghuaiyang 作者:William F ...

  6. 9个技巧让你的PyTorch模型训练变得飞快!

    公众号关注 "视学算法" 设为"星标",第一时间知晓最新干货~ 作者丨William Falcon 来源丨AI公园 不要让你的神经网络变成这样 让我们面对现实吧 ...

  7. 加速 PyTorch 模型训练的 9 个技巧

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 一个step by step的指南,非常的实用. 不要让你的 ...

  8. 9 个技巧让你的 PyTorch 模型训练变得飞快!

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 作者 | William Falcon 编译 | ronghuaiyang 来源 | ...

  9. 收藏 | 9 个技巧让你的 PyTorch 模型训练变得飞快!

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | William Falcon 编译 | rongh ...

最新文章

  1. 【注意事项】Markdown遇到的小问题
  2. python能做什么游戏ll-一个简单的python game游戏
  3. 银行办理业务观察者模式解析
  4. Linux学习:静态库和动态库
  5. sql注入一点小心得
  6. git学习3--关联不同的网址的远程分支
  7. REGEXP_REPLACE SQL正则表达式
  8. 论文笔记:3DMM(ACM1999)
  9. eclipse安装(中文)语言包插件
  10. matlab iri模型,IRI-2016 Matlab 使用教程
  11. 随机森林python反欺诈_携程金融自动化迭代反欺诈模型体系
  12. “碳壁垒”悄然而起,碳足迹如何算清楚、减明白?|双碳科普
  13. ZT:神秘的通道——三焦经
  14. typeof和instanceof的区别
  15. Java POI解析Word提取数据存储在Excel
  16. 简单数据处理(相关系数,协方差,t检验)
  17. ubuntu 下myeclipse下载,安装,破解
  18. 信息学奥赛一本通(C++版)continue
  19. linux系统无法启动 备份恢复,Linux运维 第二阶段 (十四) 备份与恢复及常见故障排除...
  20. SRTP/SRTCP协议

热门文章

  1. 手机维修培训班无门槛 深圳红警维快维 可考取手机维修工程师证书
  2. win10计算机属性此项目属性打不开,win10回收站打不开 此项目的属性未知 的解决方法...
  3. 直播竞答类产品数据分析报告火热出炉啦
  4. 数商云石油化工行业B2B集采平台解决方案:提高采购议价能力,规范化采购流程
  5. MAC版E信心跳包加密KEY的逆向
  6. 从零开始Unity3D游戏开发【2 简单的水管工例子】
  7. C语言代码示范与讲解+C语言编程规范及基础语法+编程实战
  8. LaTeX 绘图随缘记(二)
  9. 为什么很多公司都转型go语言开发?
  10. 海康工业相机连续存图、录像功能介绍