  • “SAME” and “VALID” Padding in TensorFlow Explained
    • “SAME”
    • “VALID”
    • Summary

“SAME” and “VALID” Padding in TensorFlow Explained

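I have recently been reviewing for job hunting, working through problems on Nowcoder (牛客网) alongside my internship. One of the vivo campus-recruitment questions asked for a neural-network parameter calculation. Since I had always been hazy on this topic, I decided to study it properly and used TensorFlow to take a close look at padding in convolution.
First, here is a simple TensorFlow implementation of the convolution: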

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.random_normal)

# Simulate an input image matrix of size 11x11.
# Input shape: [batch size, image height, image width, number of channels]
inp = tf.Variable(tf.random_normal([1, 11, 11, 1]))

# filter: A Tensor. Must have the same type as input. A 4-D tensor of shape
# [filter_height, filter_width, in_channels, out_channels]
# Filter shape: [filter height, filter width, image channels, number of filters]
fil = tf.Variable(tf.random_normal([3, 3, 1, 1]))

# padding must be 'VALID' or 'SAME'
result = tf.nn.conv2d(inp, fil, strides=[1, 1, 1, 1], padding='VALID')

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# print(sess.run(result))
print(result.shape)  # (1, 9, 9, 1) for this configuration
sess.close()

Varying the padding mode, the input size, and the stride gives the following results:

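“SAME”

The convolution starts when the center of the filter (K) coincides with the corner of the image, which is the case for an odd-sized filter with stride 1. (Different colors mark each stride step, with later colors covering earlier ones; k1, k2, ... mark the positions of the filter's center element.)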
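The amount of zero padding that “SAME” adds can be worked out along one spatial axis. The sketch below follows the padding rule described in the TensorFlow documentation; the function name `same_padding_1d` is hypothetical and only serves this illustration.

```python
import math


def same_padding_1d(in_size, kernel_size, stride=1):
    """Return (out_size, pad_before, pad_after) for 'SAME' padding on one axis.

    Hypothetical helper for illustration; it is not a TensorFlow function.
    """
    out_size = math.ceil(in_size / stride)
    # Total padding needed so that the last window still fits.
    pad_total = max((out_size - 1) * stride + kernel_size - in_size, 0)
    pad_before = pad_total // 2           # top / left
    pad_after = pad_total - pad_before    # bottom / right gets the extra cell
    return out_size, pad_before, pad_after


print(same_padding_1d(11, 3, 1))  # (11, 1, 1)
print(same_padding_1d(11, 3, 2))  # (6, 1, 1)
print(same_padding_1d(10, 3, 2))  # (5, 0, 1)
```

In the last case no padding is added before the first element, so the filter center does not start exactly on the image corner; that picture holds for the stride-1, odd-kernel case.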

“VALID”

The convolution is computed only at positions where the filter lies entirely inside the image; no padding is added.
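To see which positions qualify, the sketch below lists the start index of every window that fits fully inside the input along one axis. The helper name `valid_window_starts` is an assumption made for this example.

```python
def valid_window_starts(in_size, kernel_size, stride=1):
    """Start indices of all windows that lie fully inside the input.

    Hypothetical helper for illustration; it is not a TensorFlow function.
    """
    # The last valid start index is in_size - kernel_size.
    return list(range(0, in_size - kernel_size + 1, stride))


print(len(valid_window_starts(11, 3, 1)))  # 9
print(valid_window_starts(11, 3, 2))       # [0, 2, 4, 6, 8]
```

The number of start indices is the output size along that axis, so an 11-wide input with a 3-wide filter and stride 1 yields 9 outputs.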

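Summary

Note that the tensor shape of the input image and the tensor shape of the filter are laid out differently. The input image tensor has the shape:
[batch size, image height, image width, number of image channels]
while the filter tensor has the shape:
[filter height, filter width, number of image channels, number of filters]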
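Since the filter shape fully determines the number of weights, the parameter count of a convolution can be read off it directly. The following is a minimal sketch; `conv2d_param_count` is a hypothetical name, and the optional bias term is an assumption (tf.nn.conv2d itself takes no bias, but a layer built on it often adds one per output channel).

```python
def conv2d_param_count(filter_shape, use_bias=True):
    """Count trainable parameters for a filter of shape
    [filter_height, filter_width, in_channels, out_channels].

    Hypothetical helper; the bias term is an assumption about the layer.
    """
    kh, kw, in_ch, out_ch = filter_shape
    weights = kh * kw * in_ch * out_ch
    biases = out_ch if use_bias else 0
    return weights + biases


print(conv2d_param_count([3, 3, 1, 1], use_bias=False))  # 9
print(conv2d_param_count([3, 3, 64, 128]))               # 73856
```

Neither the padding mode nor the stride appears here: they change the output size, not the number of parameters.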

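Finally, if you do run into a problem, check the official API documentation or the Python source installed locally; that will often solve it more reliably than searching Baidu.

One more thing: here is part of the source code for reference: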

def convolution(
    input,  # pylint: disable=redefined-builtin
    filter,  # pylint: disable=redefined-builtin
    padding,
    strides=None,
    dilation_rate=None,
    name=None,
    data_format=None):
  # pylint: disable=line-too-long
  """Computes sums of N-D convolutions (actually cross-correlation).

  This also supports either output striding via the optional `strides` parameter
  or atrous convolution (also known as convolution with holes or dilated
  convolution, based on the French word "trous" meaning holes in English) via
  the optional `dilation_rate` parameter.  Currently, however, output striding
  is not supported for atrous convolutions.

  Specifically, in the case that `data_format` does not start with "NC", given
  a rank (N+2) `input` Tensor of shape

    [num_batches,
     input_spatial_shape[0],
     ...,
     input_spatial_shape[N-1],
     num_input_channels],

  a rank (N+2) `filter` Tensor of shape

    [spatial_filter_shape[0],
     ...,
     spatial_filter_shape[N-1],
     num_input_channels,
     num_output_channels],

  an optional `dilation_rate` tensor of shape [N] (defaulting to [1]*N)
  specifying the filter upsampling/input downsampling rate, and an optional list
  of N `strides` (defaulting [1]*N), this computes for each N-D spatial output
  position (x[0], ..., x[N-1]):

    output[b, x[0], ..., x[N-1], k] =
        sum_{z[0], ..., z[N-1], q}
            filter[z[0], ..., z[N-1], q, k] *
            padded_input[b,
                         x[0]*strides[0] + dilation_rate[0]*z[0],
                         ...,
                         x[N-1]*strides[N-1] + dilation_rate[N-1]*z[N-1],
                         q]

  where b is the index into the batch, k is the output channel number, q is the
  input channel number, and z is the N-D spatial offset within the filter. Here,
  `padded_input` is obtained by zero padding the input using an effective
  spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
  output striding `strides` as described in the
  [comment here](https://tensorflow.org/api_guides/python/nn#Convolution).

  In the case that `data_format` does start with `"NC"`, the `input` and output
  (but not the `filter`) are simply transposed as follows:

    convolution(input, data_format, **kwargs) =
      tf.transpose(convolution(tf.transpose(input, [0] + range(2,N+2) + [1]),
                               **kwargs),
                   [0, N+1] + range(1, N+1))

  It is required that 1 <= N <= 3.

  Args:
    input: An (N+2)-D `Tensor` of type `T`, of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    filter: An (N+2)-D `Tensor` with the same type as `input` and shape
      `spatial_filter_shape + [in_channels, out_channels]`.
    padding: A string, either `"VALID"` or `"SAME"`. The padding algorithm.
    strides: Optional.  Sequence of N ints >= 1.  Specifies the output stride.
      Defaults to [1]*N.  If any value of strides is > 1, then all values of
      dilation_rate must be 1.
    dilation_rate: Optional.  Sequence of N ints >= 1.  Specifies the filter
      upsampling/input downsampling rate.  In the literature, the same parameter
      is sometimes called `input stride` or `dilation`.  The effective filter
      size used for the convolution will be `spatial_filter_shape +
      (spatial_filter_shape - 1) * (rate - 1)`, obtained by inserting
      (dilation_rate[i]-1) zeros between consecutive elements of the original
      filter in each spatial dimension i.  If any value of dilation_rate is > 1,
      then all values of strides must be 1.
    name: Optional name for the returned tensor.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if `data_format`
      does not start with "NC"), or the second dimension (if `data_format`
      starts with "NC").  For N=1, the valid values are "NWC" (default) and
      "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".
      For N=3, the valid values are "NDHWC" (default) and "NCDHW".

  Returns:
    A `Tensor` with the same type as `input` of shape

        `[batch_size] + output_spatial_shape + [out_channels]`

    if data_format is None or does not start with "NC", or

        `[batch_size, out_channels] + output_spatial_shape`

    if data_format starts with "NC",
    where `output_spatial_shape` depends on the value of `padding`.

    If padding == "SAME":
      output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

    If padding == "VALID":
      output_spatial_shape[i] =
        ceil((input_spatial_shape[i] -
              (spatial_filter_shape[i]-1) * dilation_rate[i])
             / strides[i]).

  Raises:
    ValueError: If input/output depth does not match `filter` shape, if padding
      is other than `"VALID"` or `"SAME"`, or if data_format is invalid.

  """
  # pylint: enable=line-too-long
  with ops.name_scope(name, "convolution", [input, filter]) as name:
    input = ops.convert_to_tensor(input, name="input")  # pylint: disable=redefined-builtin
    input_shape = input.get_shape()
    filter = ops.convert_to_tensor(filter, name="filter")  # pylint: disable=redefined-builtin
    filter_shape = filter.get_shape()
    op = Convolution(
        input_shape,
        filter_shape,
        padding,
        strides=strides,
        dilation_rate=dilation_rate,
        name=name,
        data_format=data_format)
    return op(input, filter)

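The docstring in the source states the final output size explicitly:

If padding == "SAME": output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
If padding == "VALID": output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i])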
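The two output-size rules from the docstring can be checked without running TensorFlow. The sketch below transcribes them into plain Python for the 2-D NHWC case; the function names `conv_output_size` and `conv2d_output_shape` are hypothetical and exist only for this example.

```python
import math


def conv_output_size(in_size, kernel_size, stride=1, dilation=1, padding="VALID"):
    """Output size along one spatial axis, per the docstring formulas.

    Hypothetical helper for illustration; it is not a TensorFlow function.
    """
    if padding == "SAME":
        return math.ceil(in_size / stride)
    if padding == "VALID":
        return math.ceil((in_size - (kernel_size - 1) * dilation) / stride)
    raise ValueError("padding must be 'VALID' or 'SAME'")


def conv2d_output_shape(input_shape, filter_shape, strides=(1, 1), padding="VALID"):
    """Output shape for an NHWC input and an [h, w, in, out] filter."""
    batch, height, width, _ = input_shape
    kh, kw, _, out_ch = filter_shape
    out_h = conv_output_size(height, kh, strides[0], padding=padding)
    out_w = conv_output_size(width, kw, strides[1], padding=padding)
    return (batch, out_h, out_w, out_ch)


inp_shape, fil_shape = (1, 11, 11, 1), (3, 3, 1, 1)
print(conv2d_output_shape(inp_shape, fil_shape, padding="VALID"))  # (1, 9, 9, 1)
print(conv2d_output_shape(inp_shape, fil_shape, padding="SAME"))   # (1, 11, 11, 1)
print(conv2d_output_shape(inp_shape, fil_shape, strides=(2, 2)))   # (1, 5, 5, 1)
```

The first call uses the same 11x11 input and 3x3 filter as the TensorFlow snippet at the top of the post.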

The local path of the source file is:

Comments and corrections are welcome: tangliu625@163.com
