tf.keras CNN网络搭建笔记

这里写目录标题

  • tf.keras CNN网络搭建笔记
  • 基本流程,以LeNet为例
    • 创建Sequential模型
    • 配置模型的学习流程
    • 数据预处理
    • 模型训练与验证
  • 相关函数注释
    • Conv2D
    • MaxPooling2D
  • 其他操作
    • 自定义卷积层
    • BN层
    • Dropout

基本流程,以LeNet为例

创建Sequential模型

创建Sequential模型,并添加相应神经层

model = tf.keras.Sequential([# 卷积核数量为6,大小为3*3keras.layers.Conv2D(6, 3),# strides步长keras.layers.MaxPooling2D(pool_size=2, strides=2),keras.layers.ReLU(),keras.layers.Conv2D(16, 3),keras.layers.MaxPooling2D(pool_size=2, strides=2),keras.layers.ReLU(),# 矩阵数据拉平keras.layers.Flatten(),keras.layers.Dense(120, activation='relu'),keras.layers.Dense(84, activation='relu'),keras.layers.Dense(10, activation='softmax')
])model.build(input_shape=(batch, 28, 28, 1))

也可将

keras.layers.Dense(10, activation='softmax')

移出,在Sequential外用以下代替

model.add(keras.layers.Dense(10, activation='softmax'))

配置模型的学习流程

model.compile(optimizer = 优化器, loss = 损失函数, metrics = ["准确率”])

model.compile(optimizer=keras.optimizers.Adam(),loss = keras.losses.CategoricalCrossentropy(),metrics = ['accuracy']
)

数据预处理

tf.data.Dataset.from_tensor_slices() 函数对数据集切片
shuffle() 打乱数据,参数为样本数
batch() 函数设置 batch size 值
map() 函数进行预处理

def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255x = tf.reshape(x, [-1, 28, 28, 1])y = tf.one_hot(y, depth=10)return x, ytrain_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000)
train_db = train_db.batch(128)
train_db = train_db.map(preprocess)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000)
test_db = test_db.batch(128)
test_db = test_db.map(preprocess)

模型训练与验证

# 训练
model.fit(train_db, epochs=5)
# 验证
model.evaluate(test_db)

相关函数注释

Conv2D

tf.keras.layers.Conv2D(filters, kernel_size, strides=(1, 1), padding='valid',data_format=None, dilation_rate=(1, 1), groups=1, activation=None,use_bias=True, kernel_initializer='glorot_uniform',bias_initializer='zeros', kernel_regularizer=None,bias_regularizer=None, activity_regularizer=None, kernel_constraint=None,bias_constraint=None, **kwargs
)

filters:卷积核通道数

kernel_size:卷积核大小,用2个整数的元组或列表表示,比如(3,3),[5,5]

strides:步长

padding可选“valid”和“same”。卷积核的尺寸为S,输入的尺寸为P,padding = ‘valid“时,卷积后的height,width结果为:height =width = (P-S)/strides +1

data_format:输入数据格式,表示通道数的位置,默认为“channels_last”。“channels_first”应将数据组织为(batch_size, channels, height, width),而“channels_last”应将数据组织为(batch_size,height,width,channels)。

activation=None:相当于经过卷积输出后,在经过一次激活函数,常见的激活函数有relu,softmax,selu等

use_bias =0 、1,偏置项,0表示没有增加bias,1表示有

MaxPooling2D

tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=None, padding='valid', data_format=None,**kwargs
)

其他操作

自定义卷积层

super().init()为类继承
call()第一次调用的时候会调用 build() ,然后设置self.built = True,之后每次调用时不再调用build()

class C2(tf.keras.layers.Layer):def __init__(self):super().__init__()def build(self, input_shape):self.w = tf.random.normal([5, 5, input_shape[-1], 256])def call(self, inputs):return tf.nn.conv2d(inputs,filters=self.w,strides=1,padding=[[0, 0], [2, 2],[2, 2], [0, 0]])

BN层

和激活函数层、卷积层、全连接层、池化层一样,BN(Batch Normalization)也属于网络的一层
主要目的加快收敛速度,减少学习率、参数初始化、权重衰减系数、Drop out比例等参数调整

keras.layers.BatchNormalization()

简单逻辑如下,其中gamma和beta通过学习改变

m = K.mean(X, axis=-1, keepdims=True)        #计算均值
std = K.std(X, axis=-1, keepdims=True)           #计算标准差
X_normed = (X - m) / (std + self.epsilon)         #归一化
out = self.gamma * X_normed + self.beta           #重构变换

Dropout

缓解过拟合

tf.keras.layers.Dropout(rate, noise_shape=None, seed=None, **kwargs
)

tf.keras CNN网络搭建笔记相关推荐

  1. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

  2. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  3. yolov3从头实现(四)-- darknet53网络tf.keras搭建

    darknet53网络tf.keras搭建 一.定义darknet块类 1 .darknet块网络结构 2.darknet块实现 # 定义darknet块类 class _ResidualBlock( ...

  4. 机器学习(七)——tf.keras搭建神经网络固定模式

    一.总纲(一般按照下面6步搭建) import--导入相关模块 train,test--指定训练集与测试集 model = tf.keras.models.Sequential--在Sequentia ...

  5. CNN网络的搭建(Lenet5与ResNet18)

    CNN介绍 这里给出维基百科中对于卷积神经网络简介 卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单 ...

  6. 基于FPGA的一维卷积神经网络CNN的实现(三)训练网络搭建及参数导出(附代码)

    训练网络搭建 环境:Pytorch,Pycham,Matlab. 说明:该网络反向传播是通过软件方式生成,FPGA内部不进行反向传播计算. 该节通过Python获取训练数据集,并通过Pytorch框架 ...

  7. 简要笔记-CNN网络

    以下是CNN网络的简要介绍. 1 CNN的发展简述 CNN可以有效降低传统神经网络(全连接)的复杂性,常见的网络结构有LeNet.AlexNet.ZFNet.VGGNet.GoogleNet.ResN ...

  8. 抽取CNN网络任意层的特征,VGG模型fine-tuning实践

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 作为迁移学习的一种,finetune能够将general的特征转变为special的特征,从而 ...

  9. 第十二章_网络搭建及训练

    文章目录 第十二章 网络搭建及训练 CNN训练注意事项 第十二章 TensorFlow.pytorch和caffe介绍 12.1 TensorFlow 12.1.1 TensorFlow是什么? 12 ...

最新文章

  1. 【rnnoise源码分析】biquad滤波器
  2. -static 静态链接库的某些问题
  3. Bitmap 索引 vs. B-tree 索引:如何选择以及何时使用?——4-5
  4. [企业内部https证书配置]tomcat 7配置https的完整历程
  5. 迅雷极速与旋风的对比
  6. GoAhead2.5源代码分析之19-web层(webs.c)
  7. 队列总结(六)DelayQueue
  8. 后端开发面试自我介绍_java工程师面试自我介绍范文
  9. oracle全局索引改成本地索引,解析一个通过添加本地分区索引提高SQL性能的案例...
  10. 最近看到一篇文章拿来跟午饭们分享--养生之道补肾气
  11. 淘宝API,api各种接口
  12. 工业线上赛(2022省赛)
  13. ## 投标人出具哪种检测机构的报告才具有法律效力?
  14. 增高助长~~~~~~~~~~~~
  15. Redis是什么、特点、优势
  16. 网络爬虫逆向(全国建筑市场监管公共服务平台)
  17. Excel操作记录如何生成日志
  18. node-js上手安装。
  19. 游标卡尺原理及读数方法
  20. oracle的trim函数使用时不生效问题

热门文章

  1. linux 用户、群组及权限操作
  2. 靠加班?靠团建?靠个人?请停止无效努力!
  3. 算法我也不知道有没有下一个---一个题目的开端(索引堆与图)
  4. ORA-01919: role 'OLAPI_TRACE_USER' does not exist
  5. BZOJ3170: [Tjoi2013]松鼠聚会(切比雪夫距离转曼哈顿距离)
  6. 深度deepin安装apache tomcat
  7. FNV哈希算法【转】
  8. 使用yum安装CDH Hadoop集群
  9. iOS中常用的正则表达式
  10. JSP/Servlet中的几个编码的作用