Pytorch常用操作
创建tensor
x = torch.empty(*sizes) #创建一个未初始化的tensor(后面用torch.nn.init中的一些函数进行初始化)
>>> torch.empty(2, 3) tensor(1.00000e-08 * [[ 6.3984, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000]])
x = torch.rand(5, 3) #返回一个范围为[0,1)、size为5*3的矩阵
tensor([[0.3380, 0.3845, 0.3217],[0.8337, 0.9050, 0.2650],[0.2979, 0.7141, 0.9069],[0.1449, 0.1132, 0.1375],[0.4675, 0.3947, 0.1426]])
x = torch.zeros(5, 3, dtype=torch.long)
tensor([[0, 0, 0],[0, 0, 0],[0, 0, 0],[0, 0, 0],[0, 0, 0]])
x = torch.ones(5, 3, dtype=torch.double)
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]], dtype=torch.float64)
x = torch.tensor([5.5, 3]) #Construct a tensor directly from data
tensor([5.5000, 3.0000])
tensor运算
torch.mm(tensor1, tensor2, out=None) #tensor的矩阵乘法matrix multiplication
torch.mul(tensor1, tensor2, out=None) #tensor的点乘Hadamard product
tensor相关操作
x.size() #取tensor的size,返回的是tuple
z = x.view(-1, 8) #和reshape功能一样,只是参数少了一层括号
x = x.clamp(min, max) #取最大最小值,和numpy.clip(a, a_min, a_max, out=None)类似
torch.max()
torch.max(Tensor) #对所有元素,取最大值,返回只有一个数的tensor
torch.max(input, dim, keepdim=False, out=None) #对dim维度上的元素取最大值,返回两个tensor,第一个是dim上的最大值,第二个是最大值所在的位置(argmax)
torch.cat(seq, dim=0, out=None) #concatenate,功能和numpy.concatenate((a1, a2, ...), axis=0, out=None)一样,格式也恰好一样
一个技巧:inputs = torch.cat(inputs).view(len(inputs), 1, -1) #先cat再view(reshape)
torch.stack( (a,b,c) ,dim = 2) #建立一个新的维度,然后再在该纬度上进行拼接
torch.stack VS torch.cat:cat是在已有的维度上拼接,而stack是建立一个新的维度,然后再在该纬度上进行拼接。
用其实现 x.append(in_tensor) 的功能:先构造已经append好的x(此时x为list),然后x = torch.stack(x, dim = 0)
可参考 https://blog.csdn.net/Teeyohuang/article/details/80362756
torch.unsqueeze(input, dim, out=None) #给input(一个tensor)在dim维度上增加一个维度
>>> x = torch.tensor([1, 2, 3, 4]) >>> torch.unsqueeze(x, 0) tensor([[ 1, 2, 3, 4]]) >>> torch.unsqueeze(x, 1) tensor([[ 1], [ 2], [ 3], [ 4]])
b = a.numpy() #torch tensor转numpy array
b = torch.from_numpy(a) #numpy array转torch tensor(两种转都是没有复制,而是直接引用的)
tensor_a , idx_sort = torch.sort(tensor_a, dim=0, descending=True) #tensor排序,返回排序后的tensor和下标
tensor求导
x = torch.ones(2, 2, requires_grad=True) #创建时设置requires_grad为True,将x看成待优化的参数(权重)
model.zero_grad() #将每个权重的梯度清零(因为梯度会累加)
optimizer.zero_grad() #当optimizer=optim.Optimizer(model.parameters())时,其与model.zero_grad()等效
loss.backward() #求导,即对loss进行back propagation
optimizer.step() #在back propagation后更新参数
定义神经网络:
1. 定义网络架构(模型的forward,通常用一个继承自torch.nn.Module的类)
__init__():将nn实例化(每一个nn都是一个类),参数自己定义
forward(self, x):模型的forward,参数x为模型输入
self.add_module("conv", nn.Conv2d(10, 20, 4)) # self.conv = nn.Conv2d(10, 20, 4) 和这个增加module的方式等价
torch.nn.Embedding(num_embeddings, embedding_dim, ...) #是一个矩阵类,里面初始化了一个随机矩阵,矩阵的长是字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。
2. 定义输入输出
3. 定义loss(如果用nn需要实例化才定义,否则用functional直接在训练中用)
4. 定义优化器
训练:
1. 初始化,如model.zero_grad()将一些参数初始化为0
2. 准备好输入
3. 将模型设置为train模式
4. 将模型forward
5. 计算loss和accuracy
6. back propagation并计算权重的梯度
7. 做validation
8. 打印Epoch、loss、acc、time等信息
(不一定所有步骤都有,可以看情况省略部分)
验证或测试:
1. 准备好输入
2. 将模型设置为eval模式
3. 将模型forward
4. 计算loss和accuracy
5. 打印loss、acc等信息
(train和test相比,主要多了bp相关的,包括zero_grad()和backward()等)
torch.
max
(input,dim,keepdim=False,out=None)
转载于:https://www.cnblogs.com/sbj123456789/p/9483760.html
Pytorch常用操作相关推荐
- 收藏!PyTorch常用代码段合集
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Jack Stark,来源:极市平台 来源丨https://zhu ...
- PyTorch常用代码段合集
↑ 点击蓝字 关注视学算法 作者丨Jack Stark@知乎 来源丨https://zhuanlan.zhihu.com/p/104019160 极市导读 本文是PyTorch常用代码段合集,涵盖基本 ...
- 【深度学习】PyTorch常用代码段合集
来源 | 极市平台,机器学习算法与自然语言处理 本文是PyTorch常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处理.模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常 ...
- pytorch list转tensor_PyTorch 52.PyTorch常用代码段合集
本文参考于: Jack Stark:[深度学习框架]PyTorch常用代码段zhuanlan.zhihu.com 1. 基本配置 导入包和版本查询: import torch import torc ...
- (pytorch-深度学习系列)pytorch数据操作
pytorch数据操作 基本数据操作,都详细注释了,如下: import torch#5x3的未初始化的Tensor x = torch.empty(5, 3) print("5x3的未初始 ...
- 收藏 | PyTorch常用代码段合集
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨Jack Stark@知乎 来源丨https://zh ...
- Pytorch常用总结(持续更新...)
Pytorch 1. transform部分: 参考:transforms的二十二个方法 2. dataloader部分: dataset&dataloader: 参考:dataset& ...
- Pytorch常用技巧记录
Pytorch常用技巧记录 目录 文章目录 Pytorch常用技巧记录 1.指定GPU编号 2.查看模型每层输出详情 3.梯度裁剪(Gradient Clipping) 4.扩展单张图片维度 5.独热 ...
- 【docker容器常用操作】
docker容器常用操作 docker生成镜像 docker 加载镜像 docker生成镜像 step1: 查看需要生成镜像的容器的id sudo docker ps 例如: comacai@DGX2 ...
最新文章
- oracle 10g数据库的异步提交
- flash mini播放器
- Python 学习日记 第四天
- ubuntu中显示本机的gpu_Ubuntu下如何查看GPU版本和使用信息
- css的一些常见问题处理方法
- java 蓝桥杯 石子游戏(题解)
- insert 数组_Java数组和集合的效率问题
- 基于HTML5的iPad电子杂志横竖屏自适应方案
- windows下phpstorm的常用快捷键及使用技巧
- C语言之枚举的定义以及测试
- vue-cli 上传图片上传到OSS(阿里云)
- java线程--object.waitobject.notify
- window 和虚拟机通过tftp实现文件传输
- php程序员中文,php中文网“php程序员工具箱” v0.1版本上线
- 用c语言switch写运费的,超级新手,用switch写了个计算器程序,求指导
- java星星随机下落_随机产生星星,单击星星消失
- 制作路由器openwrt安装及配置
- (02)Cartographer源码无死角解析-(32) LocalTrajectoryBuilder2D::AddRangeData()→点云的体素滤波
- ESP8266开发之旅 阿里云生活物联网平台篇② 使用云智能App,配置自己的App,无需开发
- 郑州大学计算机系好请假吗,郑州大学网上信息
热门文章
- nslookup 包含在那个包中_nslookup命令详解
- 同一个ip能否两次加入组播_组播IGMPv1/v2/v3精华知识汇总
- android 指针是什么意思,Android系统的智能指针(轻量级指针、强指针和弱指针)的实现原理分析(3)...
- nodejs android 推送,利用Nodejs怎么实现一个微信小程序消息推送功能
- html怎么设置数据条的颜色,jQuery EasyUI 数据网格 – 条件设置行背景颜色 | 菜鸟教程...
- html约束验证的例子,HTML5利用约束验证API来检查表单的输入数据的代码实例
- matplotlib的默认字体_浅谈matplotlib默认字体设置探索
- python读取json数据格式问题_浅谈Python中的异常和JSON读写数据的实现
- java中链式调用_Java及Android中常用链式调用写法简单示例
- 光端机安装调试需注意的几大因素