PyTorch中常用工具

torchvision

  1. models:提供深度学习中各种经典的网络结构以及预训练好的模型,包括AlexNet, VGG, ResNet, Inception
  2. datsets:提供常用的数据集加载,设计上均继承torch.utils.data.dataset,主要包括MNIST,CIFAR10/100,ImageNet,COCO
  3. transforms:提供常用的数据预处理操作,主要包括对tensor和PILImage对象的操作
  4. torchvision.utils.save_image:直接将tensor保存成图片
    from torch.utils import data
    import os
    from PIL import Image
    from torchvision import transforms
    from torchvision import utils
    import numpy
    import torchclass Data(data.Dataset):def __init__(self, root):# 返回指定路径下的文件和文件夹列表。imgs_HR = os.listdir(os.path.join(root, 'gt'))self.imgs_HR = [os.path.join(root, 'gt', img) for img in imgs_HR]imgs_LR = os.listdir(os.path.join(root, 'lr'))self.imgs_LR = [os.path.join(root, 'lr', img) for img in imgs_LR]self.transform = transforms.ToTensor()def __getitem__(self, item):img_path_LR = self.imgs_LR[item]img_path_HR = self.imgs_HR[item]LR_img = Image.open(img_path_LR)HR_img = Image.open(img_path_HR)HR = self.transform(HR_img)# print(HR1.shape)LR = self.transform(LR_img)filename = os.path.splitext(os.path.basename(img_path_HR))return LR, HR, filenamedef __len__(self):return len(self.imgs_HR)if __name__ == '__main__':train = Data(root='/data/wcy/celebA-18000/test')for LR, HR, filename in train:print(filename)utils.save_image(HR,os.path.join('/data/wcy/wcy/fishnet/result-test/{}.png').format(str(filename[0])))
    

可视化工具

  1. tensorboardX:命令:tensorboard --logdir=log目录 --port=指定端口
  2. visdom:pip install visdom

GPU加速

tensor,variable,nn.module都有一个.cuda对象,通过调用该方法可以将其转为对应的GPU对象,variable和tensor.cuda会将新对象转移到GPU而其他的数据还保留在cpu。而module.cuda会将所有的数据都转移到GPU上。

服务器有多个GPU,tensor.cuda()会将tensor保存到第一块GPU上,等价于tensor.cuda(0)。指定其他GPU的方法:

  1. torch.cuda.set_device(1):指定第二块GPU
  2. 设置环境变量:CUDA_VISIBLE_DEVICES=0,2,3,设置使用第一,三、四块卡。此时tensor.cuda(1)会将tensor转移到CUDA_VISIBLE_DEVICES[1]=2,即第三块卡上。
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2"

持久化

tensor,variable,nn.module和optimizer均可以保存到硬盘中。保存:torch.save(obj, filename)加载:obj = torch.load(filename)

module和optimizer建议保存他们的state_dict()

from torchvision import transforms
import torch as t
a = t.Tensor(3, 4)
if t.cuda().is_available():a = a.cuda()t.save(a, 'a.pth')b = t.load('a.pth')# 加载为c, 存储于CPUc = t.load('a.pth', map_llocation=lambda storage, loc:storage)# 加载为d,存储于GPUd = t.load('a.pth', map_location={'cuda:1':'cuda:0'})# 设置默认类型
t.set_default_tensor_type('torch.FloatTensor')
from torchvision.models import AlexNet
model = AlexNet()
t.save(model.state_dict(), 'alexnet.pth')
model.load_state_dict(t.load('alexnet.pth'))
optimizer = t.optim.Adam(model.parameters(), lr=0.1)
t.save(optimizer.state_dict(), 'optimizer.pth')
optimizer.load_state_dict(t.load('optimizer.pth'))
all_data = dict(optimizer=optimizer.state_dict(),model = model.state_dict(),info = u'模型和优化器的所有参数'
)
t.save(all_data, 'all.pth')
all_data = t.load('all_pth')
all_data.keys()

深度学习框架PyTorch:入门与实践 学习(四)相关推荐

  1. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

  2. 深度学习框架Pytorch入门与实践——读书笔记

    2 快速入门 2.1 安装和配置 pip install torch pip install torchvision#IPython魔术命令 import torch as t a=t.Tensor( ...

  3. 深度学习框架PyTorch入门与实践:第二章 快速入门

    本章主要介绍两个内容,2.1节介绍如何安装PyTorch,以及如何配置学习环境:2.2节将带领读者快速浏览PyTorch中主要内容,给读者一个关于PyTorch的大致印象. 2.1 安装与配置 2.1 ...

  4. 深度学习框架PyTorch入门与实践:第八章 AI艺术家:神经网络风格迁移

    本章我们将介绍一个酷炫的深度学习应用--风格迁移(Style Transfer).近年来,由深度学习引领的人工智能技术浪潮越来越广泛地应用到社会各个领域.这其中,手机应用Prisma,尝试为用户的照片 ...

  5. 深度学习框架PyTorch入门与实践:第七章 AI插画师:生成对抗网络

    生成对抗网络(Generative Adversarial Net,GAN)是近年来深度学习中一个十分热门的方向,卷积网络之父.深度学习元老级人物LeCun Yan就曾说过"GAN is t ...

  6. 深度学习框架PyTorch入门与实践:第九章 AI诗人:用RNN写诗

    我们先来看一首诗. 深宫有奇物,璞玉冠何有. 度岁忽如何,遐龄复何欲. 学来玉阶上,仰望金闺籍. 习协万壑间,高高万象逼. 这是一首藏头诗,每句诗的第一个字连起来就是"深度学习". ...

  7. 深度学习框架pytorch入门之张量Tensor(一)

    文章目录 一.简介 二.查看帮助文档 三.Tensor常用方法 1.概述 2.新建方法 (1)Tensor(*sizes) tensor基础构造函数 (2)ones(*sizes) 构造一个全为1的T ...

  8. 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...

  9. 深度学习框架PyTorch一书的学习-第三章-Tensor和autograd-1-Tensor

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 Tensor Tensor可以是一个数 ...

  10. 好书分享——《深度学习框架PyTorch:入门与实践》

    内容简介 : <深度学习框架PyTorch:入门与实践>从多维数组Tensor开始,循序渐进地带领读者了解PyTorch各方面的基础知识.结合基础知识和前沿研究,带领读者从零开始完成几个经 ...

最新文章

  1. Idea的一些调试技巧及设置todo
  2. SVN使用import导入新数据到版本库
  3. 重温这几个屌爆的Python技巧!
  4. vivo手机系统升级后没有服务器,为什么安卓手机升级到了12GB还没有iPhone 4GB运行快呢?...
  5. AI队列长度检测:使用YOLO进行视频中的对象检测
  6. 中文字符存储 mysql_中文字符的存储
  7. apache php提示下载,apache正在下载php文件而不是显示它们。
  8. Git教程_1 简介
  9. 以上是对图像的椒盐噪声处理,在p_temp[j*wide+i]=0;这句程序中为什么要乘以wide,求解,谢谢!
  10. c语言中用double写圆的面积,用java写一个函数area,接收一个double类型的参数(表示圆的半径r),用于计算圆的面积...
  11. 有什么轻量级的大数据技术?
  12. 代码中的Status和State语义
  13. 浅谈CMMI3认证从评估前准备到正式评估的全部过程
  14. 了解、熟悉、精通 的三种并代表什么意思
  15. 捋一捋Unified Language Model Pre-training for Natural Language Understanding and Generation
  16. 前端开发学习笔记(一):HTML
  17. ado.net访问ORACLE数据库点滴
  18. sniffer-agent
  19. noip模拟赛 街灯
  20. 使用rpm包制作本地镜像仓库和使用httpd发布镜像服务实现内网使用yum命令

热门文章

  1. javascript如何设置名字输入不合法
  2. POJ 2409 Let it Bead(Polya简单应用)
  3. Console-算法[for]-素数
  4. CSS定位设置实例——盒子的定位
  5. GTD+敏捷=一种新的计划列表理念和方法。
  6. Hibernate 主键
  7. OpenCV-图像处理(27、模板匹配(Template Match))
  8. OpenCV-图像处理(06、调整图像亮度与对比度)
  9. 诺基亚n1支持java功能_关于诺基亚N1你必须要了解这10个问题!
  10. linux 编译java web_linux:搭建java web环境