Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv)

理解:

https://www.zhihu.com/question/43609045/answer/130868981

上一篇文章已经介绍过卷积的实现,这篇文章我们学习反卷积原理,同样,在了解反卷积原理后,在后面手写python代码实现反卷积。

反卷积用途:
上采样,
gan,反卷积来生成图片

1 反卷积原理

反卷积原理不太好用文字描述,这里直接以一个简单例子描述反卷积过程。

假设输入如下:

[[1,0,1],[0,2,1],[1,1,0]]

反卷积卷积核如下:

[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]]

现在通过stride=2来进行反卷积,使得尺寸由原来的3*3变为6*6.那么在Tensorflow框架中,反卷积的过程如下(不同框架在裁剪这步可能不一样):


其实通过我绘制的这张图,就已经把原理讲的很清楚了。大致步奏就是,先填充0,然后进行卷积,卷积过程跟上一篇文章讲述的一致。最后一步还要进行裁剪。好了,原理讲完了,(#^.^#)….

2 代码实现

上一篇文章我们只针对了输出通道数为1进行代码实现,在这篇文章中,反卷积我们将输出通道设置为多个,这样更符合实际场景。

先定义输入和卷积核:

input_data=[[[1,0,1],[0,2,1],[1,1,0]],[[2,0,2],[0,1,0],[1,0,0]],[[1,1,1],[2,2,0],[1,1,1]],[[1,1,2],[1,0,1],[0,2,2]]]
weights_data=[ [[[ 1, 0, 1],[-1, 1, 0],[ 0,-1, 0]],[[-1, 0, 1],[ 0, 0, 1],[ 1, 1, 1]],[[ 0, 1, 1],[ 2, 0, 1],[ 1, 2, 1]], [[ 1, 1, 1],[ 0, 2, 1],[ 1, 0, 1]]],[[[ 1, 0, 2],[-2, 1, 1],[ 1,-1, 0]],[[-1, 0, 1],[-1, 2, 1],[ 1, 1, 1]],[[ 0, 0, 0],[ 2, 2, 1],[ 1,-1, 1]], [[ 2, 1, 1],[ 0,-1, 1],[ 1, 1, 1]]]  ]

上面定义的输入和卷积核,在接下的运算过程如下图所示:

可以看到实际上,反卷积和卷积基本一致,差别在于,反卷积需要填充过程,并在最后一步需要裁剪。具体实现代码如下:

#根据输入map([h,w])和卷积核([k,k]),计算卷积后的feature map
import numpy as np
def compute_conv(fm,kernel):[h,w]=fm.shape [k,_]=kernel.shape r=int(k/2)#定义边界填充0后的mappadding_fm=np.zeros([h+2,w+2],np.float32)#保存计算结果rs=np.zeros([h,w],np.float32) #将输入在指定该区域赋值,即除了4个边界后,剩下的区域padding_fm[1:h+1,1:w+1]=fm #对每个点为中心的区域遍历for i in range(1,h+1):for j in range(1,w+1): #取出当前点为中心的k*k区域roi=padding_fm[i-r:i+r+1,j-r:j+r+1]#计算当前点的卷积,对k*k个点点乘后求和rs[i-1][j-1]=np.sum(roi*kernel)return rs#填充0
def fill_zeros(input):[c,h,w]=input.shapers=np.zeros([c,h*2+1,w*2+1],np.float32)for i in range(c):for j in range(h):for k in range(w): rs[i,2*j+1,2*k+1]=input[i,j,k] return rsdef my_deconv(input,weights):#weights shape=[out_c,in_c,h,w][out_c,in_c,h,w]=weights.shape   out_h=h*2out_w=w*2rs=[]for i in range(out_c):w=weights[i]tmp=np.zeros([out_h,out_w],np.float32)for j in range(in_c):conv=compute_conv(input[j],w[j])#注意裁剪,最后一行和最后一列去掉tmp=tmp+conv[0:out_h,0:out_w]rs.append(tmp)return rs def main():  input=np.asarray(input_data,np.float32)input= fill_zeros(input)weights=np.asarray(weights_data,np.float32)deconv=my_deconv(input,weights)print(np.asarray(deconv))if __name__=='__main__':main()

计算卷积代码,跟上一篇文章一致。代码直接看注释,不再解释。运行结果如下:

[[[  4.   3.   6.   2.   7.   3.][  4.   3.   3.   2.   7.   5.][  8.   6.   8.   5.  11.   2.][  3.   2.   7.   2.   3.   3.][  5.   5.  11.   3.   9.   3.][  2.   1.   4.   5.   4.   4.]][[  4.   1.   7.   0.   7.   2.][  5.   6.   0.   1.   8.   5.][  8.   0.   8.  -2.  14.   2.][  3.   3.   9.   8.   1.   0.][  3.   0.  13.   0.  11.   2.][  3.   5.   3.   1.   3.   0.]]]

为了验证实现的代码的正确性,我们使用tensorflow的conv2d_transpose函数执行相同的输入和卷积核,看看结果是否一致。验证代码如下:

import tensorflow as tf
import numpy as np
def tf_conv2d_transpose(input,weights):#input_shape=[n,height,width,channel]input_shape = input.get_shape().as_list()#weights shape=[height,width,out_c,in_c]weights_shape=weights.get_shape().as_list() output_shape=[input_shape[0], input_shape[1]*2 , input_shape[2]*2 , weights_shape[2]]print("output_shape:",output_shape)deconv=tf.nn.conv2d_transpose(input,weights,output_shape=output_shape,strides=[1, 2, 2, 1], padding='SAME')return deconvdef main(): weights_np=np.asarray(weights_data,np.float32)#将输入的每个卷积核旋转180°weights_np=np.rot90(weights_np,2,(2,3))const_input = tf.constant(input_data , tf.float32)const_weights = tf.constant(weights_np , tf.float32 )input = tf.Variable(const_input,name="input")#[c,h,w]------>[h,w,c]input=tf.transpose(input,perm=(1,2,0))#[h,w,c]------>[n,h,w,c]input=tf.expand_dims(input,0)#weights shape=[out_c,in_c,h,w]weights = tf.Variable(const_weights,name="weights")#[out_c,in_c,h,w]------>[h,w,out_c,in_c]weights=tf.transpose(weights,perm=(2,3,0,1))#执行tensorflow的反卷积deconv=tf_conv2d_transpose(input,weights) init=tf.global_variables_initializer()sess=tf.Session()sess.run(init)deconv_val  = sess.run(deconv) hwc=deconv_val[0]print(hwc) if __name__=='__main__':main()

上面代码中,有几点需要注意:

  1. 每个卷积核需要旋转180°后,再传入tf.nn.conv2d_transpose函数中,因为tf.nn.conv2d_transpose内部会旋转180°,所以提前旋转,再经过内部旋转后,能保证卷积核跟我们所使用的卷积核的数据排列一致。
  2. 我们定义的输入的shape为[c,h,w]需要转为tensorflow所使用的[n,h,w,c]。
  3. 我们定义的卷积核shape为[out_c,in_c,h,w],需要转为tensorflow反卷积中所使用的[h,w,out_c,in_c]

执行上面代码后,执行结果如下:

[[  4.   3.   6.   2.   7.   3.][  4.   3.   3.   2.   7.   5.][  8.   6.   8.   5.  11.   2.][  3.   2.   7.   2.   3.   3.][  5.   5.  11.   3.   9.   3.][  2.   1.   4.   5.   4.   4.]]
[[  4.   1.   7.   0.   7.   2.][  5.   6.   0.   1.   8.   5.][  8.   0.   8.  -2.  14.   2.][  3.   3.   9.   8.   1.   0.][  3.   0.  13.   0.  11.   2.][  3.   5.   3.   1.   3.   0.]]

对比结果可以看到,数据是一致的,证明前面手写的python实现的反卷积代码是正确的。

python 反卷积(DeConv) tensorflow反卷积(DeConv)(实现原理+手写)相关推荐

  1. tensorflow应用:双向LSTM神经网络手写数字识别

    tensorflow应用:双向LSTM神经网络手写数字识别 思路 Python程序1.建模训练保存 Tensorboard检查计算图及训练结果 打开训练好的模型进行预测 思路 将28X28的图片看成2 ...

  2. TensorFlow实战之Softmax Regression识别手写数字

       本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.相关概念 1.MNIST MNIST(Mixed N ...

  3. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  4. Tensorflow反卷积(DeConv)实现原理+手写python代码实现反卷积(DeConv)

    最近看到一个巨牛的人工智能教程,分享一下给大家.教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家.平时碎片时间可以当小说看,[点这里可以去膜拜一下大神的" ...

  5. 深度篇—— CNN 卷积神经网络(四) 使用 tf cnn 进行 mnist 手写数字 代码演示项目

    返回主目录 返回 CNN 卷积神经网络目录 上一章:深度篇-- CNN 卷积神经网络(三) 关于 ROI pooling 和 ROI Align 与 插值 本小节,细说 使用 tf cnn 进行 mn ...

  6. 深度学习21天——卷积神经网络(CNN):实现mnist手写数字识别(第1天)

    目录 一.前期准备 1.1 环境配置 1.2 CPU和GPU 1.2.1 CPU 1.2.2 GPU 1.2.3 CPU和GPU的区别 第一步:设置GPU 1.3 MNIST 手写数字数据集 第二步: ...

  7. 【人工智能实验】卷积神经网络CNN框架的实现与应用-手写数字识别

    目录 实验六 卷积神经网络CNN框架的实现与应用 一.实验目的 二.实验原理 三.实验结果 1.调整学习率.epochs以及bacth_size这三个参数,分别观察参数的变化对于实验结果的影响. 2. ...

  8. 用tensorflow.js实现浏览器内的手写数字识别

    原文 简介 Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预 ...

  9. TensorFlow神经网络(五)输入手写数字图片进行识别

    一.断点续训 为防止突然断电.参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt的操作,给所有w和b赋保存在ckpt中的值: 1. 如果存储断点文件的目录文件夹中,包含有效断 ...

最新文章

  1. 试题 入门训练 Fibonacci数列(Java)
  2. 内核电源管理器已启动关机转换_Linux系统启动流程
  3. java 画布实验报告_编辑画布图像
  4. chrome更新flash player失败
  5. 创建线程的三种方法_Netty源码分析系列之NioEventLoop的创建与启动
  6. 微软五月份安全补丁发布
  7. usb连接不上 艾德克斯电源_硬核充电宝?360汽车应急电源入手体验
  8. 轻松四步配置Oracle数据库监听
  9. QEMU CVE-2020-14364 漏洞分析(含 PoC 演示)
  10. 事件库之Redis自己的事件模型-ae
  11. FATA[0000] (省略) Are you trying to connect to a TLS-enabled daemon without TLS?
  12. 数据结构乐智教学百度云_数据结构 百度网盘分享
  13. Quartus ii仿真界面闪退
  14. Windows - 安装/卸载服务 - 学习/实践
  15. 计算机导论高清课件教程,计算机导论-PPT课件
  16. Mac、M1怎么安装Maven
  17. 图解机器学习—算法原理与Python语言实现(文末留言送书)
  18. python 三角形雷达图,python 画雷达图
  19. 蓝色配色灵感 | 解读蓝色
  20. 贝叶斯公式求解公园凉鞋问题

热门文章

  1. 阿里资深技术专家总结:要怎样努力才可以成为公司主力架构师
  2. la是什么牌子_la clover兰卡文是什么牌子_哪个国家的_什么档次?
  3. Fabric ca学习笔记
  4. 懂得爱――在亲密关系中成长
  5. 十字军东征一些君主AI的对话
  6. 黑苹果虚拟机好用吗_苹果手机上有什么好用的工作提醒便签软件工具吗?
  7. VSCode下的51单片机开发环境搭建
  8. 百度地图开发者使用教程
  9. web课程设计:HTML非遗文化网页设计题材【京剧文化】HTML+CSS(大美中国 14页 带bootstarp)
  10. 数据结构与算法之排序(Java版)