在PyTorch框架训练模型的时候,经常会看到.pth这样的文件,如果我们直接打开看是一些乱码,那这个文件是做什么的,保存了一些什么东西呢?实质其实是一个.pkl文件,想了解这个文件的可以参阅:Python基础知识汇总
既然是pkl文件,保存的也是key-value键值对,或说字典类型,我们来保存并显示看下:

#将字典保存到tony.pth文件
torch.save({'hi':123,'hello':'haha','name':'Tony'}, 'tony.pth')
#读取
t=torch.load("tony.pth")
for k,v in t.items():print(k,v)'''
hi 123
hello haha
name Tony
'''

上面是很简单的一个字典类型<class 'dict'>,一般在模型训练中,使用的是有序字典。

我们找一个例子来看下,前面在训练WGAN的时候,会生成很多.pth文件,我选一个netG_epoch_24.pth文件查看下内容。

import torch
net2 = torch.load("netG_epoch_24.pth")
print(type(net2),len(net2))for k,v in net2.items():print(k,type(v),v.size())

<class 'collections.OrderedDict'> 25
main.initial:100-512:convt.weight <class 'torch.Tensor'> torch.Size([100, 512, 4, 4])
main.initial:512:batchnorm.weight <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.bias <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.running_var <class 'torch.Tensor'> torch.Size([512])
main.initial:512:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:512-256:convt.weight <class 'torch.Tensor'> torch.Size([512, 256, 4, 4])
main.pyramid:256:batchnorm.weight <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.bias <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.running_var <class 'torch.Tensor'> torch.Size([256])
main.pyramid:256:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:256-128:convt.weight <class 'torch.Tensor'> torch.Size([256, 128, 4, 4])
main.pyramid:128:batchnorm.weight <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.bias <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.running_var <class 'torch.Tensor'> torch.Size([128])
main.pyramid:128:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.pyramid:128-64:convt.weight <class 'torch.Tensor'> torch.Size([128, 64, 4, 4])
main.pyramid:64:batchnorm.weight <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.bias <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.running_mean <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.running_var <class 'torch.Tensor'> torch.Size([64])
main.pyramid:64:batchnorm.num_batches_tracked <class 'torch.Tensor'> torch.Size([])
main.final:64-3:convt.weight <class 'torch.Tensor'> torch.Size([64, 3, 4, 4])

可以看出保存的键值类型是参数,有权重值、偏置、均值和方差等,值内容的类型是Tensor张量,整个使用的是有序的字典类型OrderedDict。

DataLoader数据加载器的用法

DataLoader数据加载器,来自torch.utils.data模块,参数可以使用数据集与采样器,可以使用多进程来处理数据集(Linux)。将数据集装载进去训练的时候,会将数据分成多个小组(每次小组数量取决于批量大小),批量的进行迭代。

import torch
import torch.utils.data as udatax=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
#使用TensorDataset将数据包装成Dataset类
dset = udata.TensorDataset(x, y)
#每次批量5个,为了更直观没有打乱,实际训练中一般都是打乱比较好,也就是shuffle=True
#loader = udata.DataLoader(dataset=dset, batch_size=5,shuffle=False,num_workers=0)
'''
epoch:0, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:0, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
epoch:1, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:1, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
epoch:2, step:0, batch_x:tensor([1., 2., 3., 4., 5.]), batch_y:tensor([10.,  9.,  8.,  7.,  6.])
epoch:2, step:1, batch_x:tensor([ 6.,  7.,  8.,  9., 10.]), batch_y:tensor([5., 4., 3., 2., 1.])
'''
#每次批量4个,没有被整除那就是10个数字除以4还剩余2个再迭代一次
#也可以指定drop_last=True,将剩余的删除掉,那剩余的就不会再迭代了
#loader = udata.DataLoader(dataset=dset, batch_size=4,shuffle=False,num_workers=0)
'''
epoch:0, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:0, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:0, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
epoch:1, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:1, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:1, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
epoch:2, step:0, batch_x:tensor([1., 2., 3., 4.]), batch_y:tensor([10.,  9.,  8.,  7.])
epoch:2, step:1, batch_x:tensor([5., 6., 7., 8.]), batch_y:tensor([6., 5., 4., 3.])
epoch:2, step:2, batch_x:tensor([ 9., 10.]), batch_y:tensor([2., 1.])
'''
indices=range(len(dset))
sub_rnd_sample=indices[:10]
#随机子采样类似于打乱,所以如果是子采样的话,就不要设定shuffle=True
subsampler = udata.sampler.SubsetRandomSampler(sub_rnd_sample)
loader = udata.DataLoader(dataset=dset, batch_size=4,sampler=subsampler)
'''
epoch:0, step:0, batch_x:tensor([ 8., 10.,  3.,  1.]), batch_y:tensor([ 3.,  1.,  8., 10.])
epoch:0, step:1, batch_x:tensor([6., 2., 9., 7.]), batch_y:tensor([5., 9., 2., 4.])
epoch:0, step:2, batch_x:tensor([4., 5.]), batch_y:tensor([7., 6.])
epoch:1, step:0, batch_x:tensor([7., 6., 8., 3.]), batch_y:tensor([4., 5., 3., 8.])
epoch:1, step:1, batch_x:tensor([5., 9., 2., 1.]), batch_y:tensor([ 6.,  2.,  9., 10.])
epoch:1, step:2, batch_x:tensor([10.,  4.]), batch_y:tensor([1., 7.])
epoch:2, step:0, batch_x:tensor([10.,  1.,  2.,  5.]), batch_y:tensor([ 1., 10.,  9.,  6.])
epoch:2, step:1, batch_x:tensor([8., 9., 6., 7.]), batch_y:tensor([3., 2., 5., 4.])
epoch:2, step:2, batch_x:tensor([3., 4.]), batch_y:tensor([8., 7.])
'''for epoch in range(3):for step, (batch_x, batch_y) in enumerate(loader):print("epoch:{}, step:{}, batch_x:{}, batch_y:{}".format(epoch,step, batch_x, batch_y))

更详细的一些参数可以查看其源码来了解

Pytorch基础知识之pth文件与DataLoader数据加载器相关推荐

  1. PyTorch基础-自定义数据集和数据加载器(2)

    处理数据样本的代码可能会变得混乱且难以维护: 理想情况下,我们想要数据集代码与模型训练代码解耦,以获得更好的可读性和模块化.PyTorch 域库提供了许多预加载的数据(例如 FashionMNIST) ...

  2. pytorch自定义数据集和数据加载器

    假设有一个保存为npy格式的numpy数据集,现在需要将其变为pytorch的数据集,并能够被数据加载器DataLoader所加载 首先自定义一个数据集类,继承torch.utils.data.Dat ...

  3. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  4. PyTorch—torch.utils.data.DataLoader 数据加载类

    文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...

  5. PyTorch 编写自定义数据集,数据加载器和转换

    本文为 pytorch 官方教程https://pytorch.org/tutorials/beginner/data_loading_tutorial.html代码的注释 w3cschool 的翻译 ...

  6. PyTorch数据加载器

    We'll be covering the PyTorch DataLoader in this tutorial. Large datasets are indispensable in the w ...

  7. [YOLO专题-19]:YOLO V5 - ultralytics代码解析-dataloader数据加载机制

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  8. c语言文件 加载内存吗,把文件中的数据加载到内存进行查找C语言实现.docx

    把文件中的数据加载到内存进行查找C语言实现 #define _CRT_SECURE_NO_WARNINGS#include#include#includechar **pp=NULL;void ini ...

  9. 第七章:在Spark集群上使用文件中的数据加载成为graph并进行操作(2)

    Spark-shell启动后我们可以在控制台看到起运行信息: 点击作业ID即可查看Spark shell运行信息: 下面我们就开始在集群上通过读取hdfs文件的方式来构建graph对象,首先要做的就是 ...

  10. 第七章:在Spark集群上使用文件中的数据加载成为graph并进行操作(3)

    你可以调整graph的构造参数来指定partition的数量. 当数据加载完毕的时候整个web-Googel.txt就缓存进了内存之中,如下所示: 可以看到数据被缓存成了edges. 下面我们使用把m ...

最新文章

  1. docker :open /var/lib/docker/tmp/GetImageBlob318829910: no such file or directory异常解决
  2. vue的自定义组件如何使用prop传值?
  3. crm---本项目的权限控制模式
  4. android布局技巧:创建高效布局
  5. JAVA面试常考系列九
  6. linux强制使用windows命名,如何强制Windows重命名带有特殊字符的文件?
  7. centos7时间同步_centos 8.x系统配置chrony时间同步服务
  8. Daily Scrum 11.6
  9. 孤读Paper——《CenterNet:Objects as Points》
  10. 09.Java数据算法
  11. WAP的技术、运动和现状(转)
  12. html缩放背景不缩放_如何在缩放通话中静音
  13. 创客集结号:3D打印如何与中小学教育有机结合?
  14. 域名检测工具图文教程
  15. excel取末尾数字_Excel公式技巧11: 从字符串中提取数字——数字位于字符串末尾...
  16. 模拟淘宝侧边服务模块鼠标悬停效果的三种实现方式总结
  17. Excel如何设置下拉选项,并应用到整列
  18. 统信下人大金仓创建表空间及导入oracle数据
  19. 查看linux下文件是否存在,linux中怎么查看文件是否存在
  20. 211大学生自我反省

热门文章

  1. 解决zui-upload(ZUI: 文件上传 - v1.8.1)移动端上传组件的bug
  2. 模拟人生显示无法连接服务器,模拟人生总是显示无法连接网络
  3. easyboot的一个严重不足
  4. ADI高速信号采集芯片与JESD204B接口简介
  5. 研究了四大计算机名校的培养方案,核心课程都在这了
  6. L3HSEC 2022秋季招新赛部分WP
  7. 色彩搭配及设计金字塔的总结
  8. VC、PE和天使投资的解释与区别?
  9. 博客上云历程(二):Docker入门介绍与使用
  10. 停车还能360全方位影像_辅助停车,新手司机就选360全景吧!