Introduce

在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453

torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。

因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。

Sample

通过一个简单的案例来输出state_dict字典对象中存放的变量

  1. #encoding:utf-8

  2. import torch

  3. import torch.nn as nn

  4. import torch.optim as optim

  5. import torchvision

  6. import numpy as mp

  7. import matplotlib.pyplot as plt

  8. import torch.nn.functional as F

  9. #define model

  10. class TheModelClass(nn.Module):

  11. def __init__(self):

  12. super(TheModelClass,self).__init__()

  13. self.conv1=nn.Conv2d(3,6,5)

  14. self.pool=nn.MaxPool2d(2,2)

  15. self.conv2=nn.Conv2d(6,16,5)

  16. self.fc1=nn.Linear(16*5*5,120)

  17. self.fc2=nn.Linear(120,84)

  18. self.fc3=nn.Linear(84,10)

  19. def forward(self,x):

  20. x=self.pool(F.relu(self.conv1(x)))

  21. x=self.pool(F.relu(self.conv2(x)))

  22. x=x.view(-1,16*5*5)

  23. x=F.relu(self.fc1(x))

  24. x=F.relu(self.fc2(x))

  25. x=self.fc3(x)

  26. return x

  27. def main():

  28. # Initialize model

  29. model = TheModelClass()

  30. #Initialize optimizer

  31. optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

  32. #print model's state_dict

  33. print('Model.state_dict:')

  34. for param_tensor in model.state_dict():

  35. #打印 key value字典

  36. print(param_tensor,'\t',model.state_dict()[param_tensor].size())

  37. #print optimizer's state_dict

  38. print('Optimizer,s state_dict:')

  39. for var_name in optimizer.state_dict():

  40. print(var_name,'\t',optimizer.state_dict()[var_name])

  41. if __name__=='__main__':

  42. main()

  43. 具体的输出结果如下:可以很清晰的观测到state_dict中存放的key和value的值

  44. Model.state_dict:

  45. conv1.weight torch.Size([6, 3, 5, 5])

  46. conv1.bias torch.Size([6])

  47. conv2.weight torch.Size([16, 6, 5, 5])

  48. conv2.bias torch.Size([16])

  49. fc1.weight torch.Size([120, 400])

  50. fc1.bias torch.Size([120])

  51. fc2.weight torch.Size([84, 120])

  52. fc2.bias torch.Size([84])

  53. fc3.weight torch.Size([10, 84])

  54. fc3.bias torch.Size([10])

  55. Optimizer,s state_dict:

  56. state {}

  57. param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [367949288, 367949432, 376459056, 381121808, 381121952, 381122024, 381121880, 381122168, 381122096, 381122312]}]

Pytorch state_dict介绍相关推荐

  1. pytorch学习笔记(九):PyTorch结构介绍

    PyTorch结构介绍 对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握.水平有限,如有错误,欢迎指错,谢谢! 几个重要的类型 和数值相关的 T ...

  2. pytorch - state_dict() , parameters() 详解

    目录 1 parameters() 1.1 model.parameters(): 1.2 model.named_parameters(): 2 state_dict() torch.nn.Modu ...

  3. Pytorch:visdom介绍

    一.介绍 在深度学习领域,模型训练是一个必须的过程,因此常常需要实时监听并可视化一些数据,如损失值loss,正确率acc等.在Tensorflow中,最常使用的工具非Tensorboard莫属:在Py ...

  4. [转载] Pytorch基础介绍

    参考链接: PyTorch的基础 Pytorch安装.https://zhuanlan.zhihu.com/p/26871672. Pytorch中文文档.https://pytorch-cn.rea ...

  5. PyTorch框架:(1)基本处理操作

    目录 1.PyTorch框架介绍 2.安装Pytorch 2.1.CPU版本的安装命令: 2.2.GPU版本的安装命令: 2.2.1.安装CUDA 3.基本使用方法 4.Pytorch中的自动求导机制 ...

  6. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  7. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  8. CNN入门+猫狗大战(Dogs vs. Cats)+PyTorch入门

    一些修改(修改后的代码) 修改原网络的输出方式.原网络采用的交叉熵torch.nn.CrossEntropyLoss()进行Loss计算,而这个函数内部是已经进行了softmax处理的(参考),所以网 ...

  9. pytorch入门使用及前置知识(2)NLP

    1 深度学习的介绍   深度学习(deep learning) 是机器学习的分支,是一种以人工神经网络为架构,对数据进行特征学习的算法.   机器学习是人工智能的一种实现方式,深度学习是机器学习中的一 ...

  10. pytorch创建模型并训练(初探文本分类问题)

        本博客对pytorch在深度学习上的使用进行了介绍,本博客并不会对怎么训练一个好的模型进行介绍(其实我也不会),我觉得训练一个好的模型首先得选对一个模型(关键的问题在于模型如何设计),然后再经 ...

最新文章

  1. 知乎热议:高数、线代应该成为计算机专业学习的重心吗?
  2. python中参数的位置传递和名称传递各有什么优缺点_Python开发TCP和UDP的区别是什么?优缺点对比总结...
  3. 变量命名规范 匈牙利 下划线 骆驼 帕斯卡
  4. python常用内置模块-Python常用内置模块之xml模块
  5. 印度式画线乘法基本操作
  6. [GitHub] 75+的 C# 数据结构和算法实现
  7. 话里话外:中小型装备制造企业竞争优势构建之路
  8. 调试远程服务器上的代码时报错:调试设置中的Python路径无效
  9. 如何挖掘评论中的关键信息
  10. cout和printf的区别
  11. Android 真实 简历
  12. 中国十大芯片企业排名
  13. linux与ipad传输文件,实用!三种iPhone与Windows电脑互传文件操作技巧,建议收藏...
  14. promise、axios 理解
  15. 武汉星起航跨境:跨境电商新蓝海,南非跨境电商市场迸发活力
  16. layui搜索重置功能
  17. 学习gitlab-runner
  18. Adobe photoshop工具箱工具名称中英文对照
  19. 菲律宾德拉斯大学计算机专业,2020年菲律宾大学以及各专业排行榜
  20. Django urls 下划线的坑-Using the URLconf defined in xxx, Django tried these URL patterns, in thi

热门文章

  1. iOS NSUserDefaults 存放位置
  2. 在html中做表格以及给表格设置高宽字体居中和表格线的粗细
  3. 专访UCloud徐亮:UCloud虚拟网络的演进之路
  4. 网站服务器中病毒或被***怎么办?
  5. 如何解决SSM框架前台传参数到后台乱码的问题
  6. QQ去除未读状态的动画
  7. Linux命令之grep
  8. WIN2003 IIS6.0+PHP+ASP+MYSQL优化配置
  9. Google搜索技巧终极收集 - 101个Google技巧
  10. VirtualBox开发环境的搭建详解