tensorflow学习笔记(三十二):conv2d_transpose (解卷积)
tensorflow学习笔记(三十二):conv2d_transpose ("解卷积")
deconv
解卷积,实际是叫做conv_transpose
, conv_transpose
实际是卷积的一个逆向过程,tf
中, 编写conv_transpose
代码的时候,心中想着一个正向的卷积过程会很有帮助。
想象一下我们有一个正向卷积:
input_shape = [1,5,5,3]
kernel_shape=[2,2,3,1]
strides=[1,2,2,1]
padding = "SAME"
那么,卷积激活后,我们会得到 x(就是上面代码的x)。那么,我们已知x,要想得到input_shape 形状的 tensor,我们应该如何使用conv2d_transpose
函数呢?
就用下面的代码
import tensorflow as tf
tf.set_random_seed(1)
x = tf.random_normal(shape=[1,3,3,1])
#正向卷积的kernel的模样
kernel = tf.random_normal(shape=[2,2,3,1])# strides 和padding也是假想中 正向卷积的模样。当然,x是正向卷积后的模样
y = tf.nn.conv2d_transpose(x,kernel,output_shape=[1,5,5,3],strides=[1,2,2,1],padding="SAME")
# 在这里,output_shape=[1,6,6,3]也可以,考虑正向过程,[1,6,6,3]
# 通过kernel_shape:[2,2,3,1],strides:[1,2,2,1]也可以
# 获得x_shape:[1,3,3,1]
# output_shape 也可以是一个 tensor
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)print(y.eval(session=sess))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
conv2d_transpose 中会计算 output_shape 能否通过给定的参数计算出 inputs的维度,如果不能,则报错
import tensorflow as tf
from tensorflow.contrib import sliminputs = tf.random_normal(shape=[3, 97, 97, 10])conv1 = slim.conv2d(inputs, num_outputs=20, kernel_size=3, stride=4)de_weight = tf.get_variable('de_weight', shape=[3, 3, 10, 20])deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=tf.shape(inputs),strides=[1, 3, 3, 1], padding='SAME')# ValueError: Shapes (3, 33, 33, 20) and (3, 25, 25, 20) are not compatible
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
上面错误的意思是:
- conv1 的 shape 是 (3, 25, 25, 20)
- 但是 deconv1 对 conv1 求导的时候,得到的导数 shape 却是 [3, 33, 33, 20],这个和
conv1
的shape 不匹配,当然要报错咯。
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as npinputs = tf.placeholder(tf.float32, shape=[None, None, None, 3])conv1 = slim.conv2d(inputs, num_outputs=20, kernel_size=3, stride=4)de_weight = tf.get_variable('de_weight', shape=[3, 3, 3, 20])deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=tf.shape(inputs),strides=[1, 3, 3, 1], padding='SAME')loss = deconv1 - inputs
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(10):data_in = np.random.normal(size=[3, 97, 97, 3])_, los_ = sess.run([train_op, loss], feed_dict={inputs: data_in})print(los_)
# InvalidArgumentError (see above for traceback): Conv2DSlowBackpropInput: Size of out_backprop doesn't match computed: actual = 25, computed = 33
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
如果 输入的 shape 有好多 None
的话,那就是另外一种 报错方式了,如上所示:
这个错误的意思是:
conv1
的 shape 第二维或第三维的 shape 是 25- 但是 deconv1 对 conv1 求导的时候,得到的 倒数 shape 的第二位或第三维却是 33
至于为什么会这样,因为 deconv
的计算方式就是 conv
求导的计算方式,conv
的计算方式,就是 decov
求导的方式。
对deconv
求导就相当于 拿着 conv_transpose
中的参数对 deconv
输出的值的导数做卷积。
如何灵活的控制 deconv 的output shape
在 conv2d_transpose()
中,有一个参数,叫 output_shape
, 如果对它传入一个 int list 的话,那么在运行的过程中,output_shape
将无法改变(传入int list
已经可以满足大部分应用的需要),但是如何更灵活的控制 output_shape
呢?
- 传入
tensor
# 可以用 placeholder
outputs_shape = tf.placeholder(dtype=tf.int32, shape=[4])
deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=output_shape,strides=[1, 3, 3, 1], padding='SAME')# 可以用 inputs 的shape,但是有点改变
inputs_shape = tf.shape(inputs)
outputs_shape = [inputs_shape[0], inputs_shape[1], inputs_shape[2], some_value]
deconv1 = tf.nn.conv2d_transpose(conv1, filter=de_weight, output_shape=outputs_shape,strides=[1, 3, 3, 1], padding='SAME')
tensorflow学习笔记(三十二):conv2d_transpose (解卷积)相关推荐
- Mr.J-- jQuery学习笔记(三十二)--jQuery属性操作源码封装
扫码看专栏 jQuery的优点 jquery是JavaScript库,能够极大地简化JavaScript编程,能够更方便的处理DOM操作和进行Ajax交互 1.轻量级 JQuery非常轻巧 2.强大的 ...
- tensorflow学习笔记(三十四):Saver(保存与加载模型)
Saver tensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., na ...
- TensorFlow学习笔记(十二)TensorFLow tensorBoard 总结
Tensorboard: 如何更直观的观察数据在神经网络中的变化,或是已经构建的神经网络的结构.上一篇文章说到,可以使用matplotlib第三方可视化,来进行一定程度上的可视化.然而Tensorfl ...
- Windows保护模式学习笔记(十二)—— 控制寄存器
Windows保护模式学习笔记(十二)-- 控制寄存器 控制寄存器 Cr0寄存器 Cr2寄存器 Cr4寄存器 控制寄存器 描述: 控制寄存器有五个,分别是:Cr0 Cr1 Cr2 Cr3 Cr4 Cr ...
- OpenCV学习笔记(十二):边缘检测:Canny(),Sobel(),Laplace(),Scharr滤波器
OpenCV学习笔记(十二):边缘检测:Canny(),Sobel(),Laplace(),Scharr滤波器 1)滤波:边缘检测的算法主要是基于图像强度的一阶和二阶导数,但导数通常对噪声很敏感,因此 ...
- MATLAB学习笔记(十二)
MATLAB学习笔记(十二) 一.数据插值 1.1 数据插值的计算机制 1.2 数据插值的matlab函数 二.曲线拟合 2.1 曲线拟合原理 2.2 曲线拟合的实现方法 三.数据插值与曲线拟合比较 ...
- Spring Cloud学习笔记【十二】Hystrix的使用和了解
Spring Cloud学习笔记[十二]Hystrix的使用和了解 Hystrix [hɪst'rɪks],中文含义是豪猪,因其背上长满棘刺,从而拥有了自我保护的能力.本文所说的Hystrix是Net ...
- 汇编入门学习笔记 (十二)—— int指令、port
疯狂的暑假学习之 汇编入门学习笔记 (十二)-- int指令.port 參考: <汇编语言> 王爽 第13.14章 一.int指令 1. int指令引发的中断 int n指令,相当于引 ...
- QT学习笔记(十二):透明窗体设置
QT学习笔记(十二):透明窗体设置 创建 My_Widget 类 基类为QWidget , My_Widget.cpp 源文件中添加代码 #include "widget.h" # ...
最新文章
- 最高标号预留与推进算法 --- 就是要比 Dinic 快!
- 直播 | AAAI 2021最佳论文:比Transformer更有效的长时间序列预测
- why My Lead OPA test add Lead fails
- 解决链接模型的可见性问题
- 用Python连接MySQL并进行CRUD
- mysql 分组查询原理,MySQL分組查詢Group By實現原理詳解
- Golang 二叉树系列【二叉树深度】
- Mono喜迁新家-http://www.xamarin.com/
- android 柱状图_安卓控件 仪表盘控件 柱状图控件 曲线控件 xamarin.android 分类器 瓶子控件 报警控件 水箱控件 进度条控件等...
- Linux查看域名对应的ip地址
- 8月前端挑战-----如何做到这个月内每天下班学习两小时
- Activiti教程(一)activiti工作流简介
- 微信小程序开发者工具详解
- nature 计算机论文,10分钟读懂6篇Nature/Science系列文章
- 装逼技能:怎样优雅地摆放桌面图标?
- 常用的三种机器学习预测方法
- Linux基础内容介绍
- 微信小程序成语小秀才,成语接龙超详细搭建教程
- 【Ubuntu】虚拟机屏幕大小共享文件
- [OHIF-Viewers]医疗数字阅片-医学影像-cornerstone-core-Cornerstone.js-Cornerstone Examples-基石实例-上...
热门文章
- leetcode28. Implement strStr() (以及个人对KMP算法理解)
- php判断字符串是否为IP,php 判断IP为有效IP地址的方法
- 用计算机进行图片处理教学设计,三年级信息技术上教学设计
- etcd 指定配置文件启动_5步完成 etcd 单机集群部署
- mega_[MEGA DEAL]带有Kotlin捆绑包的完整Android Oreo(95%折扣)
- lambdas_借助Java 8和lambdas,可以一起使用AssertJ和Awaitility
- python venv 复制_pythonenv的安装及迁移
- 凯撒密码C语言去掉空格字符,凯撒密码的问题C语言
- jexus php 重写,如何让我们的PHP在Jexus中跑起来
- Java中映射怎么实现_我们如何在Java 9的JShell中实现映射?