网络保存与加载

1.保存

torch.manual_seed(1)    # reproducible# 假数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)def save():# 建网络net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()# 训练for t in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()torch.save(net1, 'net.pkl')  # 保存整个网络
torch.save(net.state_dict(), 'net_params.pkl') # 只保存网络中的参数(速度快,占内存少)

2.加载网络

def restore_net():# restore entire net1 to net2net2 = torch.load('net.pkl')prediction = net2(x)# 只提取网络参数
def restore_params():# 新建 net3net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))# 将保存的参数复制到 net3net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# 保存 net1 (1. 整个网络, 2. 只有参数)
save()
# 提取整个网络
restore_net()
# 提取网络参数, 复制到新网络
restore_params()

3.批训练

DataLoader是torch给你用来包装你的数据的工具。所以要将自己的(numpy array或其他)数据形式转换成Tensor, 然后再放进这个包装器中。使用DataLoader的好处就是帮你有效地迭代数据。

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducibleBATCH_SIZE = 5      # 批训练的数据个数x = torch.linspace(1, 10, 10)       # x data (torch tensor)
y = torch.linspace(10, 1, 10)       # y data (torch tensor)# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)# 把 dataset 放入 DataLoader
loader = Data.DataLoader(dataset=torch_dataset,      # torch TensorDataset formatbatch_size=BATCH_SIZE,      # mini batch sizeshuffle=True,               # 要不要打乱数据 (打乱比较好)num_workers=2,              # 多线程来读数据
)for epoch in range(3):   # 训练所有!整套!数据 3 次for step, (batch_x, batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习# 假设这里就是你训练的地方...# 打出来一些数据print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',batch_x.numpy(), '| batch y: ', batch_y.numpy())"""
Epoch:  0 | Step:  0 | batch x:  [ 6.  7.  2.  3.  1.] | batch y:  [  5.   4.   9.   8.  10.]
Epoch:  0 | Step:  1 | batch x:  [  9.  10.   4.   8.   5.] | batch y:  [ 2.  1.  7.  3.  6.]
Epoch:  1 | Step:  0 | batch x:  [  3.   4.   2.   9.  10.] | batch y:  [ 8.  7.  9.  2.  1.]
Epoch:  1 | Step:  1 | batch x:  [ 1.  7.  8.  5.  6.] | batch y:  [ 10.   4.   3.   6.   5.]
Epoch:  2 | Step:  0 | batch x:  [ 3.  9.  2.  6.  7.] | batch y:  [ 8.  2.  9.  5.  4.]
Epoch:  2 | Step:  1 | batch x:  [ 10.   4.   8.   1.   5.] | batch y:  [  1.   7.   3.  10.   6.]
"""

当数据最后不足batch时,就会返回这个epoch中剩下的数据。

转载于:https://www.cnblogs.com/o-v-o/p/10946146.html

pytorch使用说明2相关推荐

  1. 【pytorch】torch.cdist使用说明

    使用说明 torch.cdist的使用介绍如官网所示, 它是批量计算两个向量集合的距离. 其中, x1和x2是输入的两个向量集合. p 默认为2,为欧几里德距离. 它的功能上等同于 scipy.spa ...

  2. 超赞的PyTorch资源大列表,GitHub标星9k+,中文版也上线了

    点击阅读原文,快速报名! 作者 | 红色石头 来源 | AI有道(ID: redstonewill) 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升.PyTorch 能在短时间内被 ...

  3. github总star超9K!一个超赞的 PyTorch 资源大列表,有人把它翻译成了中文版!

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 转自:程序员爱码士 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升 ...

  4. Github标星9k+,超赞的 PyTorch 资源大列表!

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 原来的英文版 GitHub 项目地址: https://github.com/bh ...

  5. pytorch 笔记:tensorboardX

    1 SummaryWriter 1.1 创建 首先,需要创建一个 SummaryWriter 的示例: from tensorboardX import SummaryWriter#以下是三种不同的初 ...

  6. 超赞的 PyTorch 资源大列表,有人把它翻译成了中文版!

    点击上方"AI有道",选择"星标"公众号 重磅干货,第一时间送达 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升.PyTorch 能在短时间 ...

  7. 【NLP】Github标星7.7k+:常见NLP模型的PyTorch代码实现

    推荐github上的一个NLP代码教程:nlp-tutorial,教程中包含常见的NLP模型代码实现(基于Pytorch1.0+),而且教程中的大多数NLP模型都使用少于100行代码. 教程说明 这是 ...

  8. Github标星5.4k+:常见NLP模型的代码实现(基于TensorFlow和PyTorch)

    推荐github上的一个NLP代码教程:nlp-tutorial,教程中包含常见的NLP模型代码实现(基于TensorFlow和Pytorch),而且教程中的大多数NLP模型都使用少于100行代码. ...

  9. 推荐:常见NLP模型的代码实现(基于TensorFlow和PyTorch)

    推荐github上的一个NLP代码教程:nlp-tutorial,教程中包含常见的NLP模型代码实现(基于TensorFlow和Pytorch),而且教程中的大多数NLP模型都使用少于100行代码. ...

  10. linux卸载anaconda_Win10安装Anaconda和Pytorch(CPU版)

    1.Anaconda安装 Anaconda的安装网上的教程非常非常多,很简单,下面这篇博客写的很详细,看我写的也可以. 地址:https://blog.csdn.net/u014546828/arti ...

最新文章

  1. How to reduce Index size on disk?减少ES索引大小的一些小手段
  2. javascript-tab切换效果
  3. Python合并两个List
  4. 数据结构与算法 -- 二叉树 ADT
  5. 顺序表查找+折半查找(二级)
  6. 适用于Java EE / Jakarta EE开发人员的Micronaut
  7. 程序员的工资普遍在20k以上
  8. java 方法调用表达式_java lambda怎么表达式判断被调用接口名称和接口中方法
  9. linux shell脚本字符串 字段分隔符 存入数组 根据下标取值
  10. qt qtableview 刷新列表_qt qtableview基本用法
  11. 如何玩转阿里巴巴国际站Trueview视频?
  12. unity3D制作拼图游戏
  13. 【STM32H7教程】第88章 STM32H7的SDMMC总线应用之SD卡移植FatFs文件系统
  14. 1359 信息学奥赛一本通 围成面积
  15. 清明不远游 国内赏春地推荐
  16. java动态心形程序_java swing实现动态心形图案的代码下载
  17. python爬取汽车之家数据_python 实现汽车之家车型数据爬虫
  18. 初识RFID的物理与逻辑安全机制
  19. 小程序蓝牙开发官方demo--不能发送字符串命令或发送失败10004问题
  20. 圣诞节用代码写一颗圣诞树【html5写的3D逼真圣诞树外加python无延迟的豪华圣诞树】

热门文章

  1. 看雪CTF.TSRC 2018 团队赛 第九题『谍战』 解题思路
  2. 人工智能进场 AR/VR何去何从?
  3. MaxCompute 2.0 生态开放之路及最新发展
  4. matplotlib简介
  5. 水晶報表之Datetime TO shortDate
  6. 被暴击了!22岁本科生开源的后台管理系统,太实用!
  7. 干掉 SQL 中的 like,我用 es 后运营小姐姐们都说好快!
  8. 读完 Effective Java,我整理这 59 条技巧!
  9. Google开源的操作系统Fuchsia,专为大内存硬件设计
  10. 开发、运维、测试都要了解的测试技巧