Pytorch预训练模型下载并加载(以VGG为例)自定义路径
简述
一般来讲,Pytorch用torchvision调用vgg之类的模型话,如果电脑在cache(Pytorch硬编码的一个地址)(如果在环境变量中添加了TORCH_HOME
和TORCH_MODEL_ZOO
的话,就是在这两个位置的联合的路径下,比如TORCH_MODEL_ZOO\model
)否则就是在TORCH_HOME\models
或者是~/.torch/models
比如,我的就是C:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pth
。
这很有可能并不是我们想要的下载模型放的地址,或者是这样的下载方式很慢等等。
而且这个地址不可以很容易的直接调用,非常不方便。
这点,在我现在用pytorch版本还是github上的最新版本都是没有做类似的改进的。
但是这种设计(可能对我这种强迫症来说),是有需求的。
解决办法
首先,先处理下载的问题。
读了下源码,是使用import torch.utils.model_zoo as model_zoo
里面的函数来加载数据。
整理了下源码中涉及的这一部分
from urllib.parse import urlparse
import torch.utils.model_zoo as model_zoo
import re
import os
def download_model(url, dst_path):parts = urlparse(url)filename = os.path.basename(parts.path)HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')hash_prefix = HASH_REGEX.search(filename).group(1)model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)return filename
调用实例
model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth','vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth','vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth','vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth','vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}import ospath = 'D:/Software/DataSet/models/vgg'
if not (os.path.exists(path)):os.makedirs(path)
for url in model_urls.values():download_model(url, path)
输出
100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14<00:00, 7114218.15it/s]
100%|███████████████████████████████████████████████████████████████| 532194478/532194478 [02:46<00:00, 3193007.52it/s]
100%|███████████████████████████████████████████████████████████████| 553433881/553433881 [01:13<00:00, 7536750.60it/s]
100%|██████████████████████████████████████████████████████████████| 574673361/574673361 [00:54<00:00, 10587712.79it/s]
100%|███████████████████████████████████████████████████████████████| 531503671/531503671 [01:10<00:00, 7548305.64it/s]
100%|███████████████████████████████████████████████████████████████| 532246301/532246301 [01:35<00:00, 5598996.73it/s]
100%|██████████████████████████████████████████████████████████████| 553507836/553507836 [00:50<00:00, 10900603.60it/s]
100%|███████████████████████████████████████████████████████████████| 574769405/574769405 [01:11<00:00, 8023263.07it/s]
其他的模型地址,可以打开github里面对应模型的代码,一打开就看到了。
https://github.com/pytorch/vision/tree/master/torchvision/models
再看看加载
import glob
import os
def load_model(model_name, model_dir):model = eval('models.%s(init_weights=False)' % model_name)path_format = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)model_path = glob.glob(path_format)[0]model.load_state_dict(torch.load(model_path))return model
使用实例:
model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)
输出:
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(4): ReLU(inplace)(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU(inplace)(8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU(inplace)(10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(12): ReLU(inplace)(13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(14): ReLU(inplace)(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(19): ReLU(inplace)(20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
Pytorch预训练模型下载并加载(以VGG为例)自定义路径相关推荐
- Pytorch预训练模型下载慢解决方式
最近在使用与训练网络模型ResNet50 Faster R-CNN的时候,发现系统自带的下载方式是非常慢的,而且等待好久一段时间后出现: TimeoutError: [WinError 10060] ...
- Pytorch预训练模型加载
1. 保存模型:torch.save(model.state_dict(), PATH) 加载模型:model.load_state_dict(torch.load(PATH)) model.eval ...
- Dns-prefetch DNS 预解析优化页面加载速度
Dns-prefetch DNS 预解析优化页面加载速度 浏览器访问一个链接时并不是直接将请求到网页对应的服务器上,而是先要做域名解析--将域名解析到网页对应的服务器 ip 地址,然后浏览器才能和服务 ...
- Hugging face预训练模型下载和使用
Huggingface Huggingface是一家公司,在Google发布BERT模型不久之后,这家公司推出了BERT的pytorch实现,形成一个开源库pytorch-pretrained-ber ...
- Bert,Albert,Roberta,XLNet的中英文预训练模型下载网址及教程
自然语言处理的各大热门的中英文预训练模型下载网址,包含了Bert,Albert, Roberta, XLNet等模型的base和large.tensorflow和pytorch版本的预训练模型. ht ...
- PyTorch 保存模型结构参数及加载模型
PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...
- 【ArcGIS风暴】ArcGIS Editor for OSM中文教程(2):下载及加载OSM数据
本文讲解在ArcGIS中借助OpenStreetMap工具下载并加载OSM数据. 文章目录 1. 下载OSM数据 2. 加载OSM数据 1. 下载OSM数据 在工具箱中双击Download OSM D ...
- 使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作
使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作 总共分为四步 构造一个my_dataset类,继承自torch.utils.data.Dataset 重写__getite ...
- PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快
PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...
最新文章
- C++ 中的引用 和指针的区别
- android 音乐播放器的状态栏通知,Android仿虾米音乐播放器之通知栏notification解析...
- 凑个热闹-LayoutInflater相关分析
- [VirtaulBox]网络连接设置
- 五子棋ai算法python_[深度学习]实现一个博弈型的AI,从五子棋开始(1)
- XP系统服务启动设置优化
- 阿里巴巴上市路演ppt 官方完整版
- 计算机应用u盘解释,U盘速度测试和参数解释
- VBlog项目代码理解之后端
- 魔百和CM311-1a YST免拆机卡刷精简固件
- Java项目:SSM酒店客房管理系统
- 手游沙巴克传奇当前服务器维护,《沙巴克传奇》12月18日安卓、IOS维护公告
- 30 行代码实现蚂蚁森林自动收能量
- java web工程中如何添加图片_java web中如何添加图片
- python 如何实现依据依存关系构造邻接矩阵(有向图)
- 宝塔win安装提示非服务器系统,宝塔windows面板安装
- 文件搜索工具(Python实现)
- Nginx主要用来干什么
- “最不合格”的SAP应聘者: 从大学生到SAP成都研究院开发工程师
- 《Linux From Scratch》第三部分:构建LFS系统 第七章:基本系统配置- 7.9. 创建 /etc/shells 文件...
热门文章
- Android各层推荐开发书籍及参考资料
- ES 自动恢复分片的时候不恢复了是磁盘超过了85%,然后不恢复了 ES可以配置多个数据目录...
- ZABBIX3.0配置邮件报警
- Powershell管理系列(二十五)PowerShell操作之获取AD账号及邮箱信息
- innodb行锁理解
- websphere变成英文了
- DrawerLayout + Toolbar + ViewPager
- ANSI C标准函数库
- 推荐算法实现之BMF(pymc3+MovieLen)
- Linux操作系统Ubuntu部署Oracle篇