Pytorch创建模型-小试牛刀
Pytorch创建模型
写这篇博客的初衷是因为非常多情况下需要用到pytorch的包,但是每一次调用都需要额外编写函数,评估呀什么的,特别要牵扯上攻击和防御,所以就想写个博客,总结一下,彻底研究这个内容
torch模型的定义
一般来说,都会创建一个类(继承torch.nn.Module)作为模型。一开始入门,只需要关注两个函数。
特别用来提醒torch的全连接和keras的全连接不同
def __init__(self):
#用来完成模型的细节定义
def forward(self,x):
#用来表示模型具体的前向传播过程
训练过程
数据处理会专门另作一篇博客,更新后会放上链接
最简单的训练过程可以利用torch的自带函数,设置优化器,损失函数,如下:
losses = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=torchmodel.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8,weight_decay=0, amsgrad=False)
之后在完成数据导入后,最关键的就是输入数据(前向传播),计算损失(求梯度),更新权重(反向传播)
outputs = model(X_train)
optimizer.zero_grad()
loss = losses(outputs, y_train)
loss.backward()
optimizer.step()
完整代码可以参见工程
模型保存与读取
保存与读取的时候格式要相互匹配。
torch.save(torchmodel.state_dict(), path)Mymodel.load_state_dict(torch.load(path))torch.save(torchmodel, path)Mymodel = torch.load(path)
体会
model.eval() 与 model.train() 在模型训练和测试的时候要保持好逻辑关系
如果想要查看模型特定的参数 1 for name in model.state_dict(): 2 model.named_parameters() 3 for layer in model.modules():
for layers in torchmodel.modules():
print(layers.requires_grad_()) 固定某层
一个是设置不要更新参数的网络层为false,另一个就是在定义优化器时只传入要更新的参数。当然最优的做法是,优化器中只传入requires_grad=True的参数,这样占用的内存会更小一点,效率也会更高。
更新 11.25 ———— 查看参数以及提取某层特定输出
提取某层的输出比较简单的方法就是在forward内部,对需要记录的数据保存下来。
在forward函数内,对数据保存,以load——state保存模型互不干扰!!!
另外一种,实用的方法就是hook,钩子方法。
主要实现途径是设置函数hook,完成对output的一些操作,例如保存。在通过把某个层与hook函数挂上关系,就可以实现。
但一定要在最后取消hook。
activation = {}def get_activation(name):def hook(model, input, output):# 如果你想feature的梯度能反向传播,那么去掉 detach()activation[name] = output.detach()return hookh = model.dense[0].register_forward_hook(get_activation('希望的输出'))!!!!h.remove()
Pytorch创建模型-小试牛刀相关推荐
- pytorch创建模型并训练(初探文本分类问题)
本博客对pytorch在深度学习上的使用进行了介绍,本博客并不会对怎么训练一个好的模型进行介绍(其实我也不会),我觉得训练一个好的模型首先得选对一个模型(关键的问题在于模型如何设计),然后再经 ...
- pytorch保存模型pth_Day159:模型的保存与加载
网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...
- 《Pytorch - 线性回归模型》
2020年10月4号,依然在家学习. 今天是我写的第一个 Pytorch程序,从今天起也算是入门了. 就从简单的线性回归开始吧. 话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的. imp ...
- ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植
目录 PyTorch简介 导入转换器 快速浏览模型 将PyTorch模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携 ...
- c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)
关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...
- c++ 遍历list_小白学PyTorch | 6 模型的构建访问遍历存储(附代码
文章来自微信公众号:[机器学习炼丹术],是个人的学习心得分享基地. 文章目录: 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 Sequential 1.4 小总 ...
- 一大波PyTorch图像分割模型来袭,俄罗斯程序员出品新model zoo
鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 一个新的图像分割model zoo来啦! 一大波基于PyTorch的图像分割模型整理好了就等你来用~ 这个新集合由俄罗斯的程序员小哥Pave ...
- Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它
Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它 本教程我们将描述如何将PyTorch中定义的模型转换为ONNX格式,然后使用ONNX运行时运行它. ONNX运行时是一个 ...
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()
Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...
最新文章
- 静态路由_【零基础学云计算】静态路由!静态路由!静态路由!原理与配置
- FMDB支持的事务类型
- nextcloud安装教程
- 问题 | Spare BA 中的Eigen运行错误
- 分享codeigniter框架,在zend studio 环境下的代码提示
- gdb常用命令及参考文档
- ds排序--希尔排序_图解直接插入排序和希尔排序
- XML文件的读取(XmlParserDemo)
- python编写arcgis脚本_ArcGis Python脚本——批量添加字段
- thinkphp6 加载第三方类库_thinkphp中第三方类引入问题
- .NET2.0 事务处理
- paip.windows io监控总结
- sqlserver两种分页方法比较
- 英文网站排名优化 谷歌SEO优化技巧方法
- android 恢复出厂设置流程分析,Android恢复出厂设置流程分析
- 教您在Xshell中清除历史记录
- 深圳学校积分计算机,深圳市龙岗区小学积分入学排行榜
- Python财务分析
- pyrosetta下载及安装(linux服务器)
- 统计学的Python实现-009:四分位数
热门文章
- -XX:+TraceClassLoading和-XX:+TraceClassUnloading
- AD16如何在3D环境翻转PCB
- 《阅读的方法》读书笔记2-2:遥远的地方
- 0402/0603/0805/1206封装尺寸
- 我的世界java版1.12.2版_我的世界Minecraft Java版1.12.2 pre2 发布
- word中最后一行留白太多
- BZOJ 2448: 挖油
- 按键精灵大漠插件使用基础练习入门代码
- 【Todo】【读书笔记】机器学习实战(Python版)
- bean login not found within scope