官方文档地址

接口如下

tf.nn.conv2d(
input, filters, strides, padding, data_format=‘NHWC’, dilations=None,
name=None
)

input数据

input为入参,其shape必须为4维的,其中每一维度含义如下

  • N:Batch Number
  • H:Height
  • W:Width
  • C:Num of Channles

但是顺序由data_format来指定,一般为"NHWC"或者"NCHW"

keras中的data_format一般为"NHWC"

filter数据

filter则为神经网络训练出来的kernel值,它为4维。
[filter_height, filter_width, in_channels, out_channels]

注:out_channels即为filter的个数

算法

  1. 压平filter为二维matrix,称为张量A
    压平的目的shape为
    [filter_height * filter_width * in_channels, output_channels]

  2. 将Input中元素转换为虚拟tensor,称为张量B
    虚拟tensor的shape为
    [batch, out_height, out_width, filter_height * filter_width * in_channels]

  3. 然后将上述两者做tenserdot运算,得到张量
    C=BtensordotAC=B\ tensordot\ AC=B tensordot A

计算结果shape的公式

Batch个数和输入通道数是不变的,主要是输出的高与宽的值
out_width=(input_width−filter_width)/stride_width+1out_height=(input_height−filter_height)/stride_height+1out\_width = (input\_width-filter\_width)/stride\_width+1 \newline out\_height = (input\_height-filter\_height)/stride\_height+1 out_width=(input_width−filter_width)/stride_width+1out_height=(input_height−filter_height)/stride_height+1
所以输出shape为**(N, OUT_H, OUT_W, C)**


input shap到虚拟tensor shape举例

假设input data为 [1,2,3,4,5,6,7,8,9],shape为[1,3,3,1]

我们把数据下标标记好,那么所有数据如下

同样,假设filter为(1,2,2,1),使用该filter在input data上滑动截取[2,2],效果如下

根据上面介绍的shape计算公式,虚拟tensor的shape应该为(1,2,2,4),一共四个16个数据,我们穷举出所有的数值来看一下:

(0,0,x,x)

(0,1,x,x)

所以一共生成了16个数据,shape为(1,2,2,4)


代码样例

# Batch 1
# Height 3
# Width 3
# Channels 1
input_shape=(1,3,3,1)
x_in=np.linspace(1,9,num=9).reshape(input_shape)# Height 2
# Width 2
# in channels 1
# out channels 1
kernel_in = np.array([1, 2, 3, 4]).reshape((2,2,1,1))
print('kernel shape: ', kernel_in.shape)x = tf.constant(x_in, dtype=tf.float32)
kernel = tf.constant(kernel_in, dtype=tf.float32)
y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')print("============================")
print(y)

【tensorflow】tf.nn.conv2d的使用相关推荐

  1. TensorFlow tf.nn.conv2d是怎样实现卷积的?

    [TensorFlow]tf.nn.conv2d是怎样实现卷积的? 原文:http://blog.csdn.net/mao_xiao_feng/article/details/78004522 实验环 ...

  2. TensorFlow 从入门到精通(八):TensorFlow tf.nn.conv2d 一路追查

    读者可能还记得本系列博客(二)和(六)中 tf.nn 模块,其中最关心的是 conv2d 这个函数. 首先将博客(二) MNIST 例程中 convolutional.py 关键源码列出: def m ...

  3. 【TensorFlow】理解tf.nn.conv2d方法 ( 附代码详解注释 )

    最近在研究学习TensorFlow,在做识别手写数字的demo时,遇到了tf.nn.conv2d这个方法,查阅了官网的API 发现讲得比较简略,还是没理解.google了一下,参考了网上一些朋友写得博 ...

  4. 【TensorFlow】tf.nn.conv2d是怎样实现卷积的?

    int height_col= (height + 2 * pad_h - kernel_h) / stride_h + 1; int width_col = (width + 2 * pad_w - ...

  5. tensorflow详解-tf.nn.conv2d(),tf.nn.max_pool()

    tf.nn.conv2d() 函数来计算卷积,weights 作为滤波器,[1, 2, 2, 1] 作为 strides.TensorFlow 对每一个 input 维度使用一个单独的 stride ...

  6. TensorFlow学习笔记(十七)tf.nn.conv2d

    在给定的4D input与filter下计算2D卷积输入shape为[batch, height, width, in_channels] TensorFlow的CNN代码中有 tf.nn.conv2 ...

  7. 【TensorFlow】TensorFlow函数精讲之tf.nn.conv2d()

    博客之星评选,谢谢您的支持!微信.qq五连击投票(无需关注.无需登录) 人工智能博士(投票链接):http://m234140.nofollow.ax.mvote.cn/opage/4fddfa73- ...

  8. TensorFlow基础篇(七)——tf.nn.conv2d()

    tf.nn.conv2d是TensorFlow里面实现卷积的函数,是搭建卷积神经网络比较核心的一个方法. 函数格式: tf.nn.conv2d(input, filter, strides, padd ...

  9. [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d

    [版权说明] TensorFlow 学习笔记参考: 李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇  顾思宇 著 TensorFlow实战Google ...

  10. TensorFlow学习——tf.nn.conv2d和tf.contrib.slim.conv2d的区别

    在查看代码的时候,看到有代码用到卷积层是tf.nn.conv2d,也有的使用的卷积层是tf.contrib.slim.conv2d,这两个函数调用的卷积层是否一致,在查看了API的文档,以及slim. ...

最新文章

  1. python跟易语言那个写辅助_易语言写练练看辅助
  2. springMVC 返回类型选择 以及 SpringMVC中model,modelMap.request,session取值顺序
  3. python工程师月薪多少-Python全栈工程师为何这么火薪资这么高看了才知道
  4. Exchange 默认数据库删除问题
  5. python处理shp和栅格文件的相关库shapefile、gdal等
  6. C# 中使用log4net 日志记录
  7. eclipse安装cppcheck
  8. 剑指offer之对称的二叉树
  9. 下载配置Gradle
  10. 蓝牙AVRCP协议分析
  11. 《星科快报》第二期:元宇宙之道.
  12. LaTeX 常用符号命令大全
  13. laravel的Eloquent模型
  14. Tanzu系列:第8部分 - 创建tkg集群
  15. golang —— go语言科学记数法使用
  16. phpnow mysql升级_【php】升级phpnow1.5.6的Mysql
  17. 什么是rootkit
  18. 美赛论文Latex简易模板 | 快速上手(附注释)
  19. SpringMVC核心知识的梳理(现在都用SpringBoot了,但是SpringMVC还的学的扎实点,饮水思源)
  20. 数据智能公司袋鼠云完成 6000 万元 A 轮融资

热门文章

  1. Windows下有关NDK安装出现的问题的总结
  2. linux 线程库在哪里,linux线程库
  3. 类型数据合并去重 mysql_MySQL基础知识 数据类型和数据表管理
  4. qchart画完以后删除_Unity2019基础教程:TileMap搭建像素画场景关卡
  5. 在线作图丨高级的微生物分析——在线做Variance Partitioning Analysis(VPA分析)
  6. gg.gap:ggplot阶截断坐标轴的优秀完美解决方案
  7. Briefings in Bioinformatics:微生物基因组学和功能基因组学相关软件和数据库的研究进展
  8. Cytoscape: MCODE增强包的网络模块化分析
  9. seaborn可视化绘制双变量分组条形图(Customizing Annotation of Bars: Side-by-side)、添加数值标签进行标记、并自定义条形图数值标签的格式
  10. python使用np.linspace函数生成均匀的浮点数列表实战:生成浮点数列表、生成浮点数列表(指定是否包含末尾值)