简述

一般来讲,Pytorch用torchvision调用vgg之类的模型话,如果电脑在cache(Pytorch硬编码的一个地址)(如果在环境变量中添加了TORCH_HOMETORCH_MODEL_ZOO的话,就是在这两个位置的联合的路径下,比如TORCH_MODEL_ZOO\model)否则就是在TORCH_HOME\models或者是~/.torch/models

比如,我的就是C:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pth

这很有可能并不是我们想要的下载模型放的地址,或者是这样的下载方式很慢等等。

而且这个地址不可以很容易的直接调用,非常不方便。

这点,在我现在用pytorch版本还是github上的最新版本都是没有做类似的改进的。

但是这种设计(可能对我这种强迫症来说),是有需求的。

解决办法

首先,先处理下载的问题。

读了下源码,是使用import torch.utils.model_zoo as model_zoo里面的函数来加载数据。
整理了下源码中涉及的这一部分

from urllib.parse import urlparse
import torch.utils.model_zoo as model_zoo
import re
import os
def download_model(url, dst_path):parts = urlparse(url)filename = os.path.basename(parts.path)HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')hash_prefix = HASH_REGEX.search(filename).group(1)model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)return filename

调用实例

model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth','vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth','vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth','vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth','vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}import ospath = 'D:/Software/DataSet/models/vgg'
if not (os.path.exists(path)):os.makedirs(path)
for url in model_urls.values():download_model(url, path)

输出

100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14<00:00, 7114218.15it/s]
100%|███████████████████████████████████████████████████████████████| 532194478/532194478 [02:46<00:00, 3193007.52it/s]
100%|███████████████████████████████████████████████████████████████| 553433881/553433881 [01:13<00:00, 7536750.60it/s]
100%|██████████████████████████████████████████████████████████████| 574673361/574673361 [00:54<00:00, 10587712.79it/s]
100%|███████████████████████████████████████████████████████████████| 531503671/531503671 [01:10<00:00, 7548305.64it/s]
100%|███████████████████████████████████████████████████████████████| 532246301/532246301 [01:35<00:00, 5598996.73it/s]
100%|██████████████████████████████████████████████████████████████| 553507836/553507836 [00:50<00:00, 10900603.60it/s]
100%|███████████████████████████████████████████████████████████████| 574769405/574769405 [01:11<00:00, 8023263.07it/s]

其他的模型地址,可以打开github里面对应模型的代码,一打开就看到了。

https://github.com/pytorch/vision/tree/master/torchvision/models

再看看加载

import glob
import os
def load_model(model_name, model_dir):model  = eval('models.%s(init_weights=False)' % model_name)path_format = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)model_path = glob.glob(path_format)[0]model.load_state_dict(torch.load(model_path))return model

使用实例:

model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)

输出:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(4): ReLU(inplace)(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(7): ReLU(inplace)(8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(9): ReLU(inplace)(10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(12): ReLU(inplace)(13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(14): ReLU(inplace)(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(19): ReLU(inplace)(20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

Pytorch预训练模型下载并加载(以VGG为例)自定义路径相关推荐

  1. Pytorch预训练模型下载慢解决方式

    最近在使用与训练网络模型ResNet50 Faster R-CNN的时候,发现系统自带的下载方式是非常慢的,而且等待好久一段时间后出现: TimeoutError: [WinError 10060] ...

  2. Pytorch预训练模型加载

    1. 保存模型:torch.save(model.state_dict(), PATH) 加载模型:model.load_state_dict(torch.load(PATH)) model.eval ...

  3. Dns-prefetch DNS 预解析优化页面加载速度

    Dns-prefetch DNS 预解析优化页面加载速度 浏览器访问一个链接时并不是直接将请求到网页对应的服务器上,而是先要做域名解析--将域名解析到网页对应的服务器 ip 地址,然后浏览器才能和服务 ...

  4. Hugging face预训练模型下载和使用

    Huggingface Huggingface是一家公司,在Google发布BERT模型不久之后,这家公司推出了BERT的pytorch实现,形成一个开源库pytorch-pretrained-ber ...

  5. Bert,Albert,Roberta,XLNet的中英文预训练模型下载网址及教程

    自然语言处理的各大热门的中英文预训练模型下载网址,包含了Bert,Albert, Roberta, XLNet等模型的base和large.tensorflow和pytorch版本的预训练模型. ht ...

  6. PyTorch 保存模型结构参数及加载模型

    PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...

  7. 【ArcGIS风暴】ArcGIS Editor for OSM中文教程(2):下载及加载OSM数据

    本文讲解在ArcGIS中借助OpenStreetMap工具下载并加载OSM数据. 文章目录 1. 下载OSM数据 2. 加载OSM数据 1. 下载OSM数据 在工具箱中双击Download OSM D ...

  8. 使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

    使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作 总共分为四步 构造一个my_dataset类,继承自torch.utils.data.Dataset 重写__getite ...

  9. PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快

    PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...

最新文章

  1. C++ 中的引用 和指针的区别
  2. android 音乐播放器的状态栏通知,Android仿虾米音乐播放器之通知栏notification解析...
  3. 凑个热闹-LayoutInflater相关分析
  4. [VirtaulBox]网络连接设置
  5. 五子棋ai算法python_[深度学习]实现一个博弈型的AI,从五子棋开始(1)
  6. XP系统服务启动设置优化
  7. 阿里巴巴上市路演ppt 官方完整版
  8. 计算机应用u盘解释,U盘速度测试和参数解释
  9. VBlog项目代码理解之后端
  10. 魔百和CM311-1a YST免拆机卡刷精简固件
  11. Java项目:SSM酒店客房管理系统
  12. 手游沙巴克传奇当前服务器维护,《沙巴克传奇》12月18日安卓、IOS维护公告
  13. 30 行代码实现蚂蚁森林自动收能量
  14. java web工程中如何添加图片_java web中如何添加图片
  15. python 如何实现依据依存关系构造邻接矩阵(有向图)
  16. 宝塔win安装提示非服务器系统,宝塔windows面板安装
  17. 文件搜索工具(Python实现)
  18. Nginx主要用来干什么
  19. “最不合格”的SAP应聘者: 从大学生到SAP成都研究院开发工程师
  20. 《Linux From Scratch》第三部分:构建LFS系统 第七章:基本系统配置- 7.9. 创建 /etc/shells 文件...

热门文章

  1. Android各层推荐开发书籍及参考资料
  2. ES 自动恢复分片的时候不恢复了是磁盘超过了85%,然后不恢复了 ES可以配置多个数据目录...
  3. ZABBIX3.0配置邮件报警
  4. Powershell管理系列(二十五)PowerShell操作之获取AD账号及邮箱信息
  5. innodb行锁理解
  6. websphere变成英文了
  7. DrawerLayout + Toolbar + ViewPager
  8. ANSI C标准函数库
  9. 推荐算法实现之BMF(pymc3+MovieLen)
  10. Linux操作系统Ubuntu部署Oracle篇