tensorflow学习笔记(三十二):conv2d_transpose ("解卷积")

deconv解卷积,实际是叫做conv_transposeconv_transpose实际是卷积的一个逆向过程,tf 中, 编写conv_transpose代码的时候,心中想着一个正向的卷积过程会很有帮助。

想象一下我们有一个正向卷积: 
input_shape = [1,5,5,3] 
kernel_shape=[2,2,3,1] 
strides=[1,2,2,1] 
padding = "SAME"

那么,卷积激活后,我们会得到 x(就是上面代码的x)。那么,我们已知x,要想得到input_shape 形状的 tensor,我们应该如何使用conv2d_transpose函数呢? 
就用下面的代码

import tensorflow as tf
tf.set_random_seed(1)
x = tf.random_normal(shape=[1,3,3,1])
#正向卷积的kernel的模样
kernel = tf.random_normal(shape=[2,2,3,1])# strides 和padding也是假想中 正向卷积的模样。当然,x是正向卷积后的模样
y = tf.nn.conv2d_transpose(x,kernel,output_shape=[1,5,5,3],strides=[1,2,2,1],padding="SAME")
# 在这里,output_shape=[1,6,6,3]也可以,考虑正向过程,[1,6,6,3]
# 通过kernel_shape:[2,2,3,1],strides:[1,2,2,1]也可以
# 获得x_shape:[1,3,3,1]
# output_shape 也可以是一个 tensor
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)print(y.eval(session=sess))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

conv2d_transpose 中会计算 output_shape 能否通过给定的参数计算出 inputs的维度,如果不能,则报错

import tensorflow as tf
from tensorflow.contrib import sliminputs = tf.random_normal(shape=[3, 97, 97, 10])conv1 = slim.conv2d(inputs, num_outputs=20, kernel_size=3, stride=4)de_weight = tf.get_variable('de_weight', shape=[3, 3, 10, 20])deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=tf.shape(inputs),strides=[1, 3, 3, 1], padding='SAME')# ValueError: Shapes (3, 33, 33, 20) and (3, 25, 25, 20) are not compatible
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

上面错误的意思是:

  • conv1 的 shape 是 (3, 25, 25, 20)
  • 但是 deconv1 对 conv1 求导的时候,得到的导数 shape 却是 [3, 33, 33, 20],这个和 conv1 的shape 不匹配,当然要报错咯。
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as npinputs = tf.placeholder(tf.float32, shape=[None, None, None, 3])conv1 = slim.conv2d(inputs, num_outputs=20, kernel_size=3, stride=4)de_weight = tf.get_variable('de_weight', shape=[3, 3, 3, 20])deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=tf.shape(inputs),strides=[1, 3, 3, 1], padding='SAME')loss = deconv1 - inputs
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(10):data_in = np.random.normal(size=[3, 97, 97, 3])_, los_ = sess.run([train_op, loss], feed_dict={inputs: data_in})print(los_)
# InvalidArgumentError (see above for traceback): Conv2DSlowBackpropInput: Size of out_backprop doesn't match computed: actual = 25, computed = 33
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

如果 输入的 shape 有好多 None 的话,那就是另外一种 报错方式了,如上所示: 
这个错误的意思是:

  • conv1 的 shape 第二维或第三维的 shape 是 25
  • 但是 deconv1 对 conv1 求导的时候,得到的 倒数 shape 的第二位或第三维却是 33

至于为什么会这样,因为 deconv 的计算方式就是 conv 求导的计算方式,conv 的计算方式,就是 decov 求导的方式。

deconv 求导就相当于 拿着 conv_transpose 中的参数对 deconv 输出的值的导数做卷积。

如何灵活的控制 deconv 的output shape

在 conv2d_transpose() 中,有一个参数,叫 output_shape, 如果对它传入一个 int list 的话,那么在运行的过程中,output_shape 将无法改变(传入int list已经可以满足大部分应用的需要),但是如何更灵活的控制 output_shape 呢?

  • 传入 tensor
# 可以用 placeholder
outputs_shape = tf.placeholder(dtype=tf.int32, shape=[4])
deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=output_shape,strides=[1, 3, 3, 1], padding='SAME')# 可以用 inputs 的shape,但是有点改变
inputs_shape = tf.shape(inputs)
outputs_shape = [inputs_shape[0], inputs_shape[1], inputs_shape[2], some_value]
deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=outputs_shape,strides=[1, 3, 3, 1], padding='SAME')     

tensorflow学习笔记(三十二):conv2d_transpose (解卷积)相关推荐

  1. Mr.J-- jQuery学习笔记(三十二)--jQuery属性操作源码封装

    扫码看专栏 jQuery的优点 jquery是JavaScript库,能够极大地简化JavaScript编程,能够更方便的处理DOM操作和进行Ajax交互 1.轻量级 JQuery非常轻巧 2.强大的 ...

  2. tensorflow学习笔记(三十四):Saver(保存与加载模型)

    Saver tensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., na ...

  3. TensorFlow学习笔记(十二)TensorFLow tensorBoard 总结

    Tensorboard: 如何更直观的观察数据在神经网络中的变化,或是已经构建的神经网络的结构.上一篇文章说到,可以使用matplotlib第三方可视化,来进行一定程度上的可视化.然而Tensorfl ...

  4. Windows保护模式学习笔记(十二)—— 控制寄存器

    Windows保护模式学习笔记(十二)-- 控制寄存器 控制寄存器 Cr0寄存器 Cr2寄存器 Cr4寄存器 控制寄存器 描述: 控制寄存器有五个,分别是:Cr0 Cr1 Cr2 Cr3 Cr4 Cr ...

  5. OpenCV学习笔记(十二):边缘检测:Canny(),Sobel(),Laplace(),Scharr滤波器

    OpenCV学习笔记(十二):边缘检测:Canny(),Sobel(),Laplace(),Scharr滤波器 1)滤波:边缘检测的算法主要是基于图像强度的一阶和二阶导数,但导数通常对噪声很敏感,因此 ...

  6. MATLAB学习笔记(十二)

    MATLAB学习笔记(十二) 一.数据插值 1.1 数据插值的计算机制 1.2 数据插值的matlab函数 二.曲线拟合 2.1 曲线拟合原理 2.2 曲线拟合的实现方法 三.数据插值与曲线拟合比较 ...

  7. Spring Cloud学习笔记【十二】Hystrix的使用和了解

    Spring Cloud学习笔记[十二]Hystrix的使用和了解 Hystrix [hɪst'rɪks],中文含义是豪猪,因其背上长满棘刺,从而拥有了自我保护的能力.本文所说的Hystrix是Net ...

  8. 汇编入门学习笔记 (十二)—— int指令、port

    疯狂的暑假学习之  汇编入门学习笔记 (十二)--  int指令.port 參考: <汇编语言> 王爽 第13.14章 一.int指令 1. int指令引发的中断 int n指令,相当于引 ...

  9. QT学习笔记(十二):透明窗体设置

    QT学习笔记(十二):透明窗体设置 创建 My_Widget 类 基类为QWidget , My_Widget.cpp 源文件中添加代码 #include "widget.h" # ...

最新文章

  1. 最高标号预留与推进算法 --- 就是要比 Dinic 快!
  2. 直播 | AAAI 2021最佳论文:比Transformer更有效的长时间序列预测
  3. why My Lead OPA test add Lead fails
  4. 解决链接模型的可见性问题
  5. 用Python连接MySQL并进行CRUD
  6. mysql 分组查询原理,MySQL分組查詢Group By實現原理詳解
  7. Golang 二叉树系列【二叉树深度】
  8. Mono喜迁新家-http://www.xamarin.com/
  9. android 柱状图_安卓控件 仪表盘控件 柱状图控件 曲线控件 xamarin.android 分类器 瓶子控件 报警控件 水箱控件 进度条控件等...
  10. Linux查看域名对应的ip地址
  11. 8月前端挑战-----如何做到这个月内每天下班学习两小时
  12. Activiti教程(一)activiti工作流简介
  13. 微信小程序开发者工具详解
  14. nature 计算机论文,10分钟读懂6篇Nature/Science系列文章
  15. 装逼技能:怎样优雅地摆放桌面图标?
  16. 常用的三种机器学习预测方法
  17. Linux基础内容介绍
  18. 微信小程序成语小秀才,成语接龙超详细搭建教程
  19. 【Ubuntu】虚拟机屏幕大小共享文件
  20. [OHIF-Viewers]医疗数字阅片-医学影像-cornerstone-core-Cornerstone.js-Cornerstone Examples-基石实例-上...

热门文章

  1. leetcode28. Implement strStr() (以及个人对KMP算法理解)
  2. php判断字符串是否为IP,php 判断IP为有效IP地址的方法
  3. 用计算机进行图片处理教学设计,三年级信息技术上教学设计
  4. etcd 指定配置文件启动_5步完成 etcd 单机集群部署
  5. mega_[MEGA DEAL]带有Kotlin捆绑包的完整Android Oreo(95%折扣)
  6. lambdas_借助Java 8和lambdas,可以一起使用AssertJ和Awaitility
  7. python venv 复制_pythonenv的安装及迁移
  8. 凯撒密码C语言去掉空格字符,凯撒密码的问题C语言
  9. jexus php 重写,如何让我们的PHP在Jexus中跑起来
  10. Java中映射怎么实现_我们如何在Java 9的JShell中实现映射?