【tensorflow】tf.nn.conv2d的使用
官方文档地址
接口如下
tf.nn.conv2d(
input, filters, strides, padding, data_format=‘NHWC’, dilations=None,
name=None
)
input数据
input为入参,其shape必须为4维的,其中每一维度含义如下
- N:Batch Number
- H:Height
- W:Width
- C:Num of Channles
但是顺序由data_format来指定,一般为"NHWC"或者"NCHW"
keras中的data_format一般为"NHWC"
filter数据
filter则为神经网络训练出来的kernel值,它为4维。
[filter_height, filter_width, in_channels, out_channels]
注:out_channels即为filter的个数
算法
压平filter为二维matrix,称为张量A
压平的目的shape为
[filter_height * filter_width * in_channels, output_channels]
将Input中元素转换为虚拟tensor,称为张量B
虚拟tensor的shape为
[batch, out_height, out_width, filter_height * filter_width * in_channels]
然后将上述两者做tenserdot运算,得到张量
C=BtensordotAC=B\ tensordot\ AC=B tensordot A
计算结果shape的公式
Batch个数和输入通道数是不变的,主要是输出的高与宽的值
out_width=(input_width−filter_width)/stride_width+1out_height=(input_height−filter_height)/stride_height+1out\_width = (input\_width-filter\_width)/stride\_width+1 \newline out\_height = (input\_height-filter\_height)/stride\_height+1 out_width=(input_width−filter_width)/stride_width+1out_height=(input_height−filter_height)/stride_height+1
所以输出shape为**(N, OUT_H, OUT_W, C)**
input shap到虚拟tensor shape举例
假设input data为 [1,2,3,4,5,6,7,8,9],shape为[1,3,3,1]
我们把数据下标标记好,那么所有数据如下
同样,假设filter为(1,2,2,1),使用该filter在input data上滑动截取[2,2],效果如下
根据上面介绍的shape计算公式,虚拟tensor的shape应该为(1,2,2,4),一共四个16个数据,我们穷举出所有的数值来看一下:
(0,0,x,x)
(0,1,x,x)
所以一共生成了16个数据,shape为(1,2,2,4)
代码样例
# Batch 1
# Height 3
# Width 3
# Channels 1
input_shape=(1,3,3,1)
x_in=np.linspace(1,9,num=9).reshape(input_shape)# Height 2
# Width 2
# in channels 1
# out channels 1
kernel_in = np.array([1, 2, 3, 4]).reshape((2,2,1,1))
print('kernel shape: ', kernel_in.shape)x = tf.constant(x_in, dtype=tf.float32)
kernel = tf.constant(kernel_in, dtype=tf.float32)
y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')print("============================")
print(y)
【tensorflow】tf.nn.conv2d的使用相关推荐
- TensorFlow tf.nn.conv2d是怎样实现卷积的?
[TensorFlow]tf.nn.conv2d是怎样实现卷积的? 原文:http://blog.csdn.net/mao_xiao_feng/article/details/78004522 实验环 ...
- TensorFlow 从入门到精通(八):TensorFlow tf.nn.conv2d 一路追查
读者可能还记得本系列博客(二)和(六)中 tf.nn 模块,其中最关心的是 conv2d 这个函数. 首先将博客(二) MNIST 例程中 convolutional.py 关键源码列出: def m ...
- 【TensorFlow】理解tf.nn.conv2d方法 ( 附代码详解注释 )
最近在研究学习TensorFlow,在做识别手写数字的demo时,遇到了tf.nn.conv2d这个方法,查阅了官网的API 发现讲得比较简略,还是没理解.google了一下,参考了网上一些朋友写得博 ...
- 【TensorFlow】tf.nn.conv2d是怎样实现卷积的?
int height_col= (height + 2 * pad_h - kernel_h) / stride_h + 1; int width_col = (width + 2 * pad_w - ...
- tensorflow详解-tf.nn.conv2d(),tf.nn.max_pool()
tf.nn.conv2d() 函数来计算卷积,weights 作为滤波器,[1, 2, 2, 1] 作为 strides.TensorFlow 对每一个 input 维度使用一个单独的 stride ...
- TensorFlow学习笔记(十七)tf.nn.conv2d
在给定的4D input与filter下计算2D卷积输入shape为[batch, height, width, in_channels] TensorFlow的CNN代码中有 tf.nn.conv2 ...
- 【TensorFlow】TensorFlow函数精讲之tf.nn.conv2d()
博客之星评选,谢谢您的支持!微信.qq五连击投票(无需关注.无需登录) 人工智能博士(投票链接):http://m234140.nofollow.ax.mvote.cn/opage/4fddfa73- ...
- TensorFlow基础篇(七)——tf.nn.conv2d()
tf.nn.conv2d是TensorFlow里面实现卷积的函数,是搭建卷积神经网络比较核心的一个方法. 函数格式: tf.nn.conv2d(input, filter, strides, padd ...
- [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d
[版权说明] TensorFlow 学习笔记参考: 李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇 顾思宇 著 TensorFlow实战Google ...
- TensorFlow学习——tf.nn.conv2d和tf.contrib.slim.conv2d的区别
在查看代码的时候,看到有代码用到卷积层是tf.nn.conv2d,也有的使用的卷积层是tf.contrib.slim.conv2d,这两个函数调用的卷积层是否一致,在查看了API的文档,以及slim. ...
最新文章
- python跟易语言那个写辅助_易语言写练练看辅助
- springMVC 返回类型选择 以及 SpringMVC中model,modelMap.request,session取值顺序
- python工程师月薪多少-Python全栈工程师为何这么火薪资这么高看了才知道
- Exchange 默认数据库删除问题
- python处理shp和栅格文件的相关库shapefile、gdal等
- C# 中使用log4net 日志记录
- eclipse安装cppcheck
- 剑指offer之对称的二叉树
- 下载配置Gradle
- 蓝牙AVRCP协议分析
- 《星科快报》第二期:元宇宙之道.
- LaTeX 常用符号命令大全
- laravel的Eloquent模型
- Tanzu系列:第8部分 - 创建tkg集群
- golang —— go语言科学记数法使用
- phpnow mysql升级_【php】升级phpnow1.5.6的Mysql
- 什么是rootkit
- 美赛论文Latex简易模板 | 快速上手(附注释)
- SpringMVC核心知识的梳理(现在都用SpringBoot了,但是SpringMVC还的学的扎实点,饮水思源)
- 数据智能公司袋鼠云完成 6000 万元 A 轮融资
热门文章
- Windows下有关NDK安装出现的问题的总结
- linux 线程库在哪里,linux线程库
- 类型数据合并去重 mysql_MySQL基础知识 数据类型和数据表管理
- qchart画完以后删除_Unity2019基础教程:TileMap搭建像素画场景关卡
- 在线作图丨高级的微生物分析——在线做Variance Partitioning Analysis(VPA分析)
- gg.gap:ggplot阶截断坐标轴的优秀完美解决方案
- Briefings in Bioinformatics:微生物基因组学和功能基因组学相关软件和数据库的研究进展
- Cytoscape: MCODE增强包的网络模块化分析
- seaborn可视化绘制双变量分组条形图(Customizing Annotation of Bars: Side-by-side)、添加数值标签进行标记、并自定义条形图数值标签的格式
- python使用np.linspace函数生成均匀的浮点数列表实战:生成浮点数列表、生成浮点数列表(指定是否包含末尾值)