因为需求,需要调研tensorRT与ONNX关于自定义层的方法。经过之前的调研,首先,关于onnx,开发者手册中的介绍有限,在已知的demo中没有关于onnx自定义层的,详情见TensorRT 5.1.5.0入门 Pytorch & ONNX.
后来,自己下载了onnx-tensorrt的代码onnx-tensorrt,发现NvOnnxParser.h中只写了IParser类,不像NvCaffeParser.h包装了IPluginFactory和IPluginFactoryExt等类,所以相当于是没有自定义层接口的。
然后查询资料发现,pytorch基于op,所以学习一下pytorch扩展op的方法。

文章结构

  • 用python进行pytorch扩展(继承autograd.Function)
    • 创建op操作(函数)(扩展torch.autograd)
    • 使用自定义的op(扩展 torch.nn)创建自定义layer
      • nn.Parameter
      • .data.uniform_(-0.1,0.1)
  • 用C++进行pytorch扩展(继承autograd.Function)
  • 用C进行pytorch扩展(继承autograd.Function)
  • Python 积累
    • 1. super()
    • 2. isinstance()

去pytorch1.1.0的官网看,有extending pytorch的模块 extending_pytorch,有三种扩展方法,python,c++和C。中文版本 扩展pytorch.

用python进行pytorch扩展(继承autograd.Function)

添加op,需要用autograd的Function为每个操作实现一个新的子类。其实,Function就是autograd用来计算结果和渐变以及编码操作历史的模块。

创建op操作(函数)(扩展torch.autograd)

同时参考了之前版本pytorch自定义层实现博文Pytorch入门学习(八)-----自定义层的实现(甚至不可导operation的backward写法),根据步骤定义了自己的LinearFunction.

from torch.autograd import Function@staticmethoddef forward(ctx,input,weight,bias=None,beta_f=1.0,alpha_f=1.0):ctx.save_for_backward(input,weight,bias)ctx.beta=beta_fctx.alpha=alpha_foutput=input.mm(weight.t())if bias is not None:output+=bias.unsqueeze(0).expand_as(output)return output@staticmethoddef backward(ctx,grad_output):input,weight,bias=ctx.saved_variablesgrad_input=grad_weight=grad_bias=Noneif ctx.needs_input_grad[0]:grad_input=grad_output.mm(weight)if ctx.needs_input_grad[1]:grad_weight=grad_output.t().mm(input)if bias is not None and ctx.needs_input_grad[2]:grad_bias=grad_output.sum(0).squeeze(0)return grad_input,grad_weight,grad_bias,None,None

总的来说,我的理解就是继承torch.autograd.Function基类,然后写forward和backward操作,symbolic是后面转onnx相关的,和自定义op暂时无关。
在1.1.0版本的pytorch,variable和tensor已经统一了,所以好像不用考虑很多博客提到的转tensor问题。
forward:
(1)首先用save_for_backward是用来存tensor的,存起来之后留给backward的时候用。
(2)下面就是全连接层的常规操作。
backward:
(1)首先,从ctx中取出需要的tensor(以前的功能也有转成variable,现在可能没有了吧,我太知道)
(2)grad_output,是否梯度改变,即requires_grad是否为True取决于在外面调用.backward或是.grad时候的那个Variable是不是需要grad的。如果那个Variable是需要grad的,那么我们这里反向的grad_ouput也是requires_grad为True,那么我们甚至可以计算二阶梯度。用WGAN-GP之类的。

使用自定义的op(扩展 torch.nn)创建自定义layer

官网上说(划重点):
nn 包含两种接口 - modules和他们的functional版本。通过这两个接口,你都可以扩展nn。但是我们建议,在扩展layer的时候,使用modules, 因为modules保存着参数和buffer。如果不需要参数的话,那么建议使用functional(激活函数,pooling等)。
增加一个operation的 functional版本已经在上面一节介绍完毕。
所以就用Module了,都是官网上有的代码:

    def __init__(self,input_featrues,output_features,bias=True,beta=1.0,alpha=1.0):super(Linear,self).__init__()self.input_features=input_featruesself.output_features=output_featuresself.weight=nn.Parameter(torch.Tensor(output_features,input_featrues))if bias:self.bias=nn.Parameter(torch.Tensor(output_features))else:self.register_parameter('bias',None)self.weight.data.uniform_(-0.1,0.1)if bias is not None:self.bias.data.uniform_(-0.1,0.1)# 这里的beta和alpha没有实际用处,只是证明使用自定义的op,在torch->onnx过程中,是可以传递网络参数的。self.beta=betaself.alpha=alphadef forward(self,input):return LinearFunction.apply(input,self.weight,self.bias,self.beta,self.alpha)

这里,除了告诉我,要用.apply来引用自定义op之外。
作为pytorch小白的我,学到的就是nn.Parameter,uniform_

nn.Parameter

torch.nn.Parameterm,在看过很多博客的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
换句话说就是个可训练的tensor
主要用于达到学习的效果,如下例所示,concat注意力机制

class Attn(torch.nn.Module):def __init__(self,method,hidden_size):super(Attn,self).__init__()self.method=methodif self.method not in ['dot','general','concat']:raise ValueError(self.method, "is not an supportted method")self.hidden_size=hidden_sizeif self.method == 'general':self.attn=torch.nn.Linear(self.hidden_size,hidden_size)elif self.method=='concat':self.attn=torch.nn.Linear(self.hidden_size*2,hidden_size)self.v=torch.nn.Parameter(torch.FloatTensor(hidden_size))...

self.v需要不断学习,而实验发现,linear里面的weight和bias就是parameter类型,且不能够使用tensor类型替换,还有linear里面的weight甚至可能通过指定一个不同于初始化时候的形状进行模型的更改。大致是这么写的:

...
self.linear_weight=torch.nn.Parameter(torch.nn.uniform_(-0.1,0.1)
self.linear_bias=torch.nn.Parameter(torch.zeros(in_dim,hid))
...

.data.uniform_(-0.1,0.1)

这个就是个初始化,uniform_指的是均匀分布。

用C++进行pytorch扩展(继承autograd.Function)

用C进行pytorch扩展(继承autograd.Function)

啊啊啊,满地都是坑啊,等着我来填....人生好难我好烦,冲鸭~

Python 积累

1. super()

描述
super()函数是调用父类(超类)的一个方法。
super是用来解决多继承问题的,直接用类名调用父类方法在使用单继承的时候没有问题,但是如果使用多继承,会涉及到查找顺序(MRO)、重复调用(钻石继承)等种种问题。
MRO就是类的方法解析顺序表,其实就是继承父类方法的顺序表。
语法
super(type[,objct-or-type])
参数
type:类
object-or-type:类,一般是self
实例
使用实例

class A:def add(self, x):y = x+1print(y)
class B(A):def add(self, x):super().add(x)
b = B()
b.add(2)  # 3

关于多继承问题的,super函数使用实例:

class FooParent(object):def __init__(self):self.parent = 'I\'m the parent.'print ('Parent')def bar(self,message):print ("%s from Parent" % message)class FooChild(FooParent):def __init__(self):# super(FooChild,self) 首先找到 FooChild 的父类(就是类 FooParent),然后把类 FooChild 的对象转换为类 FooParent 的对象super(FooChild,self).__init__()    print ('Child')def bar(self,message):super(FooChild, self).bar(message)print ('Child bar fuction')print (self.parent)if __name__ == '__main__':fooChild = FooChild()fooChild.bar('HelloWorld')

即super(FooChild,self)指的就是把现在的FooChild类中的self变成FooChild父类的self,这一行就执行了FooParent的初始化函数。
扩展

class parent():def __init__(self):print("parent")self.hungary=12def add(self,x):return x+1
class child(parent):def __init__(self):super(child,self).__init__()print("child")def add(self,x):return x+100input=0
a=child()
print(a.add(input))

输出为100
所以是优先调用子类中定义的同名函数的。

2. isinstance()

描述
isinstance()函数来判断一个对象是否是一个已知的类型,类似 type()。
isinstance() 与 type() 区别:

type() 不会认为子类是一种父类类型,不考虑继承关系。
isinstance() 会认为子类是一种父类类型,考虑继承关系

如果要判断两个类型是否相同推荐使用 isinstance()。
语法
isinstance(object,classinfo)
参数
object – 实例对象
classinfo – 可以是直接或间接类名、基本类型或者由他们组成的元组
实例


>>>a = 2
>>> isinstance (a,int)
True
>>> isinstance (a,str)
False
>>> isinstance (a,(str,int,list))    # 是元组中的一个返回 True
True

type和isinstance的区别:


class A:passclass B(A):passisinstance(A(), A)    # returns True
type(A()) == A        # returns True
isinstance(B(), A)    # returns True
type(B()) == A        # returns False

Pytorch1.1.0 入门 自定义op(python)相关推荐

  1. TensorFlow使用Python自定义op和损失函数

    TensorFlow使用Python自定义op和损失函数 TensorFlow是静态图结构,即必须把所有的操作以及网络结构定义好(后来有了动态图功能,即Eager Execution ),在没有用tf ...

  2. 0基础学python难吗-0基础学武汉Python开发课程有多难?该怎么入门?

    Python语言可谓十分强大,正如它的两个外号所称,一个是"内置电池",另一个是"胶水语言".开源社区和独立开发者长期为Python贡献了丰富大量的第三方库,其 ...

  3. python入门指南-python3.6.0入门指南(官方版).pdf

    您所在位置:网站首页 > 海量文档 &nbsp>&nbsp计算机&nbsp>&nbspPython python3.6.0入门指南(官方版).pdf7 ...

  4. python入门指南小说免费阅读-python3.6.0入门指南(官方版).pdf

    您所在位置:网站首页 > 海量文档 &nbsp>&nbsp计算机&nbsp>&nbspPython python3.6.0入门指南(官方版).pdf7 ...

  5. 给深度学习入门者的Python快速教程 - 番外篇之Python-OpenCV

    转载自:https://zhuanlan.zhihu.com/p/24425116 本篇是前面两篇教程:给深度学习入门者的Python快速教程 - 基础篇 给深度学习入门者的Python快速教程 - ...

  6. tensorflow:自定义op

    比官网介绍的更好理解,特此转载 tensorflow:自定义op简单介绍 2017年06月26日 13:32:55 阅读数:6094 tensorflow 自定义 op 本文只是简单的翻译了 http ...

  7. tensorflow自定义op:梯度

    暂时并未解决我的问题,但感觉将来会有用,特此转载 . 在使用 tensorflow 的时候,有时不可避免的会需要自定义 op,官方文档对于 定义 op 的前向过程介绍挺详细,但是对于 梯度 的介绍有点 ...

  8. tensorflow:自定义op简单介绍

    本文只是简单的翻译了 https://www.tensorflow.org/extend/adding_an_op 的简单部分,高级部分请移步官网. 可能需要新定义 c++ operation 的几种 ...

  9. python 入门题库————python语句和基础数理

    python 入门题库 python 题库 Python使用符号_______表示注释 Python不支持的数据类型有 查看python版本的命令是 在Python中,print(type(16/4) ...

最新文章

  1. 国内第一本项目管理的实践书籍——《IT项目管理那些事儿》
  2. mysql 数据库引擎介绍_MYSQL 数据库引擎介绍
  3. 开发VR游戏的基本要求
  4. condition的作用
  5. 实现ip数据包抓取并分析_一些网站https证书出现问题的情况分析
  6. 代码同步工具_构建现代化的命令行工具
  7. java rsa2加密算法_java RSA加密解密
  8. 电脑文件夹同步到云端
  9. 【docker】如何在docker中执行redis命令
  10. PTA 数据结构与算法分析 7-38 寻找大富翁 (25 分)
  11. vue.js开发微信公众号加载缓慢出现的白页问题-随笔
  12. A callback was made on a garbage collected delegate of type...
  13. python生成快递取件码没了怎么办_货到速递易,但没有收到取件码,怎么办
  14. Python制做动态图
  15. DSPE-PEG-NGR,NGR-PEG-DSPE,磷脂-聚乙二醇-靶向肽NRG
  16. 那位把每天当做试用期的女孩,升职为总裁助理了
  17. Vs Cood更新失败
  18. android 倒计时翻牌子,android倒计时(整理)
  19. 浅谈LLC变换器的设计经历
  20. 斗鱼直播画面怎么弄到自己网页上_如何一部手机玩转手游直播

热门文章

  1. 认证管理(锐捷网关篇)
  2. L2-015 互评成绩
  3. 脚本显示服务器超时,服务器诡异的请求超时问题
  4. 抖音跳转微信加好友功能实现解析
  5. 如何将已有图片做成透明水印_如何用Photoshop在图片上添加透明水印?
  6. Python - 期货CTP常见问题解答
  7. 【UML】— 用例图
  8. div的display和visible的区别
  9. 设置Windows本地DNS域名解析hosts
  10. FROM_UNIXTIME()函数UNIX_TIMESTAMP()函数