Pytorch创建模型

写这篇博客的初衷是因为非常多情况下需要用到pytorch的包,但是每一次调用都需要额外编写函数,评估呀什么的,特别要牵扯上攻击和防御,所以就想写个博客,总结一下,彻底研究这个内容

torch模型的定义

一般来说,都会创建一个类(继承torch.nn.Module)作为模型。一开始入门,只需要关注两个函数。

特别用来提醒torch的全连接和keras的全连接不同

def __init__(self):
#用来完成模型的细节定义
def forward(self,x):
#用来表示模型具体的前向传播过程

训练过程

数据处理会专门另作一篇博客,更新后会放上链接

最简单的训练过程可以利用torch的自带函数,设置优化器,损失函数,如下:

losses = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=torchmodel.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0, amsgrad=False)

之后在完成数据导入后,最关键的就是输入数据(前向传播),计算损失(求梯度),更新权重(反向传播)

outputs = model(X_train)
optimizer.zero_grad()
loss = losses(outputs, y_train)
loss.backward()
optimizer.step()

完整代码可以参见工程

模型保存与读取

保存与读取的时候格式要相互匹配。

    torch.save(torchmodel.state_dict(), path)Mymodel.load_state_dict(torch.load(path))torch.save(torchmodel, path)Mymodel = torch.load(path)

体会

  1. model.eval() 与 model.train() 在模型训练和测试的时候要保持好逻辑关系

  2. 如果想要查看模型特定的参数 1 for name in model.state_dict(): 2 model.named_parameters() 3 for layer in model.modules():
    for layers in torchmodel.modules():
    print(layers.requires_grad_()) 固定某层
    一个是设置不要更新参数的网络层为false,另一个就是在定义优化器时只传入要更新的参数。当然最优的做法是,优化器中只传入requires_grad=True的参数,这样占用的内存会更小一点,效率也会更高。

更新 11.25 ———— 查看参数以及提取某层特定输出

提取某层的输出比较简单的方法就是在forward内部,对需要记录的数据保存下来。
在forward函数内,对数据保存,以load——state保存模型互不干扰!!!

另外一种,实用的方法就是hook,钩子方法。

主要实现途径是设置函数hook,完成对output的一些操作,例如保存。在通过把某个层与hook函数挂上关系,就可以实现。
但一定要在最后取消hook。

    activation = {}def get_activation(name):def hook(model, input, output):# 如果你想feature的梯度能反向传播,那么去掉 detach()activation[name] = output.detach()return hookh = model.dense[0].register_forward_hook(get_activation('希望的输出'))!!!!h.remove()

Pytorch创建模型-小试牛刀相关推荐

  1. pytorch创建模型并训练(初探文本分类问题)

        本博客对pytorch在深度学习上的使用进行了介绍,本博客并不会对怎么训练一个好的模型进行介绍(其实我也不会),我觉得训练一个好的模型首先得选对一个模型(关键的问题在于模型如何设计),然后再经 ...

  2. pytorch保存模型pth_Day159:模型的保存与加载

    网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...

  3. 《Pytorch - 线性回归模型》

    2020年10月4号,依然在家学习. 今天是我写的第一个 Pytorch程序,从今天起也算是入门了. 就从简单的线性回归开始吧. 话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的. imp ...

  4. ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

    目录 PyTorch简介 导入转换器 快速浏览模型 将PyTorch模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携 ...

  5. c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

  6. c++ 遍历list_小白学PyTorch | 6 模型的构建访问遍历存储(附代码

    文章来自微信公众号:[机器学习炼丹术],是个人的学习心得分享基地. 文章目录: 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 Sequential 1.4 小总 ...

  7. 一大波PyTorch图像分割模型来袭,俄罗斯程序员出品新model zoo

    鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 一个新的图像分割model zoo来啦! 一大波基于PyTorch的图像分割模型整理好了就等你来用~ 这个新集合由俄罗斯的程序员小哥Pave ...

  8. Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它

    Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它 本教程我们将描述如何将PyTorch中定义的模型转换为ONNX格式,然后使用ONNX运行时运行它. ONNX运行时是一个 ...

  9. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

最新文章

  1. 静态路由_【零基础学云计算】静态路由!静态路由!静态路由!原理与配置
  2. FMDB支持的事务类型
  3. nextcloud安装教程
  4. 问题 | Spare BA 中的Eigen运行错误
  5. 分享codeigniter框架,在zend studio 环境下的代码提示
  6. gdb常用命令及参考文档
  7. ds排序--希尔排序_图解直接插入排序和希尔排序
  8. XML文件的读取(XmlParserDemo)
  9. python编写arcgis脚本_ArcGis Python脚本——批量添加字段
  10. thinkphp6 加载第三方类库_thinkphp中第三方类引入问题
  11. .NET2.0 事务处理
  12. paip.windows io监控总结
  13. sqlserver两种分页方法比较
  14. 英文网站排名优化 谷歌SEO优化技巧方法
  15. android 恢复出厂设置流程分析,Android恢复出厂设置流程分析
  16. 教您在Xshell中清除历史记录
  17. 深圳学校积分计算机,深圳市龙岗区小学积分入学排行榜
  18. Python财务分析
  19. pyrosetta下载及安装(linux服务器)
  20. 统计学的Python实现-009:四分位数

热门文章

  1. -XX:+TraceClassLoading和-XX:+TraceClassUnloading
  2. AD16如何在3D环境翻转PCB
  3. 《阅读的方法》读书笔记2-2:遥远的地方
  4. 0402/0603/0805/1206封装尺寸
  5. 我的世界java版1.12.2版_我的世界Minecraft Java版1.12.2 pre2 发布
  6. word中最后一行留白太多
  7. BZOJ 2448: 挖油
  8. 按键精灵大漠插件使用基础练习入门代码
  9. 【Todo】【读书笔记】机器学习实战(Python版)
  10. bean login not found within scope