TensorFlow2 手把手教你实现自定义层

  • 概述
  • Sequential
  • Model & Layer
  • 案例
    • 数据集介绍
    • 完整代码

概述

通过自定义网络, 我们可以自己创建网络并和现有的网络串联起来, 从而实现各种各样的网络结构.

Sequential

Sequential 是 Keras 的一个网络容器. 可以帮助我们将多层网络封装在一起.

通过 Sequential 我们可以把现有的层已经我们自己的层实现结合, 一次前向传播就可以实现数据从第一层到最后一层的计算.

格式:

tf.keras.Sequential(layers=None, name=None
)

例子:

# 5层网络模型
model = tf.keras.Sequential([tf.keras.layers.Dense(256, activation=tf.nn.relu),tf.keras.layers.Dense(128, activation=tf.nn.relu),tf.keras.layers.Dense(64, activation=tf.nn.relu),tf.keras.layers.Dense(32, activation=tf.nn.relu),tf.keras.layers.Dense(10)
])

Model & Layer

通过 Model 和 Layer 的__init__call()我们可以自定义层和模型.

Model:

class My_Model(tf.keras.Model):  # 继承Modeldef __init__(self):"""初始化"""super(My_Model, self).__init__()self.fc1 = My_Dense(784, 256)  # 第一层self.fc2 = My_Dense(256, 128)  # 第二层self.fc3 = My_Dense(128, 64)  # 第三层self.fc4 = My_Dense(64, 32)  # 第四层self.fc5 = My_Dense(32, 10)  # 第五层def call(self, inputs, training=None):"""在Model被调用的时候执行:param inputs: 输入:param training: 默认为None:return: 返回输出"""x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)return x

Layer:

class My_Dense(tf.keras.layers.Layer):  # 继承Layerdef __init__(self, input_dim, output_dim):"""初始化:param input_dim::param output_dim:"""super(My_Dense, self).__init__()# 添加变量self.kernel = self.add_variable("w", [input_dim, output_dim])  # 权重self.bias = self.add_variable("b", [output_dim])  # 偏置def call(self, inputs, training=None):"""在Layer被调用的时候执行, 计算结果:param inputs: 输入:param training: 默认为None:return: 返回计算结果"""# y = w * x + bout = inputs @ self.kernel + self.biasreturn out

案例

数据集介绍

CIFAR-10 是由 10 类不同的物品组成的 6 万张彩色图片的数据集. 其中 5 万张为训练集, 1 万张为测试集.

完整代码

import tensorflow as tfdef pre_process(x, y):# 转换xx = 2 * tf.cast(x, dtype=tf.float32) / 255 - 1  # 转换为-1~1的形式x = tf.reshape(x, [-1, 32 * 32 * 3])  # 把x铺平# 转换yy = tf.convert_to_tensor(y)  # 转换为0~1的形式y = tf.one_hot(y, depth=10)  # 转成one_hot编码# 返回x, yreturn x, ydef get_data():"""获取数据:return:"""# 获取数据(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()# 调试输出维度print(X_train.shape)  # (50000, 32, 32, 3)print(y_train.shape)  # (50000, 1)# squeezey_train = tf.squeeze(y_train)  # (50000, 1) => (50000,)y_test = tf.squeeze(y_test)  # (10000, 1) => (10000,)# 分割训练集train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000, seed=0)train_db = train_db.batch(batch_size).map(pre_process).repeat(iteration_num)  # 迭代20次# 分割测试集test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(10000, seed=0)test_db = test_db.batch(batch_size).map(pre_process)return train_db, test_dbclass My_Dense(tf.keras.layers.Layer):  # 继承Layerdef __init__(self, input_dim, output_dim):"""初始化:param input_dim::param output_dim:"""super(My_Dense, self).__init__()# 添加变量self.kernel = self.add_weight("w", [input_dim, output_dim])  # 权重self.bias = self.add_weight("b", [output_dim])  # 偏置def call(self, inputs, training=None):"""在Layer被调用的时候执行, 计算结果:param inputs: 输入:param training: 默认为None:return: 返回计算结果"""# y = w * x + bout = inputs @ self.kernel + self.biasreturn outclass My_Model(tf.keras.Model):  # 继承Modeldef __init__(self):"""初始化"""super(My_Model, self).__init__()self.fc1 = My_Dense(32 * 32 * 3, 256)  # 第一层self.fc2 = My_Dense(256, 128)  # 第二层self.fc3 = My_Dense(128, 64)  # 第三层self.fc4 = My_Dense(64, 32)  # 第四层self.fc5 = My_Dense(32, 10)  # 第五层def call(self, inputs, training=None):"""在Model被调用的时候执行:param inputs: 输入:param training: 默认为None:return: 返回输出"""x = self.fc1(inputs)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)return x# 定义超参数
batch_size = 256  # 一次训练的样本数目
learning_rate = 0.001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.CategoricalCrossentropy(from_logits=True)  # 损失
network = My_Model()  # 实例化网络# 调试输出summary
network.build(input_shape=[None, 32 * 32 * 3])
print(network.summary())# 组合
network.compile(optimizer=optimizer,loss=loss,metrics=["accuracy"])if __name__ == "__main__":# 获取分割的数据集train_db, test_db = get_data()# 拟合network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)

输出结果:

Model: "my__model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
my__dense (My_Dense)         multiple                  786688
_________________________________________________________________
my__dense_1 (My_Dense)       multiple                  32896
_________________________________________________________________
my__dense_2 (My_Dense)       multiple                  8256
_________________________________________________________________
my__dense_3 (My_Dense)       multiple                  2080
_________________________________________________________________
my__dense_4 (My_Dense)       multiple                  330
=================================================================
Total params: 830,250
Trainable params: 830,250
Non-trainable params: 0
_________________________________________________________________
None
(50000, 32, 32, 3)
(50000, 1)
2021-06-15 14:35:26.600766: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/5
3920/3920 [==============================] - 39s 10ms/step - loss: 0.9676 - accuracy: 0.6595 - val_loss: 1.8961 - val_accuracy: 0.5220
Epoch 2/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.3338 - accuracy: 0.8831 - val_loss: 3.3207 - val_accuracy: 0.5141
Epoch 3/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1713 - accuracy: 0.9410 - val_loss: 4.2247 - val_accuracy: 0.5122
Epoch 4/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1237 - accuracy: 0.9581 - val_loss: 4.9458 - val_accuracy: 0.5050
Epoch 5/5
3920/3920 [==============================] - 42s 11ms/step - loss: 0.1003 - accuracy: 0.9666 - val_loss: 5.2425 - val_accuracy: 0.5097

TensorFlow2 手把手教你实现自定义层相关推荐

  1. TensorFlow2 手把手教你避开梯度消失和梯度爆炸

    TensorFlow2 手把手教你避开梯度消失和梯度爆炸 梯度消失 & 梯度爆炸 梯度消失 梯度爆炸 张量限幅 tf.clip_by_value tf.clip_by_norm mnist 展 ...

  2. 手把手教你 Tableau 自定义地理编码(十九)

    手把手教你 Tableau 自定义地理编码 Tableau 绘制地图时,支持的地理位置数据有限.当我们需要 Tableau 识别我们自定义的地理位置数据时,我们可以使用 Tableau 的自定义地理编 ...

  3. 手把手教你 springboot 自定义注解 (含代码)

    来源:moon聊技术 https://mp.weixin.qq.com/s/s-B8i1wyJESqvrKp10ytQg 女朋友 : 我想要我自己的注解,你教我! moon : 诶?你怎么突然想要自己 ...

  4. Spring Boot - 手把手教小师妹自定义Spring Boot Starter

    文章目录 Pre 自定义starter的套路 步骤 命名规范 官方命名空间 自定义命名空间 实战 创建一个父maven项目:springboot_custome_starter 创建 两个Module ...

  5. 干货 | 手把手教你iOS自定义视频压缩

    作者简介 孙龙波,携程内容信息研发部 Native 开发 leader.目前主要负责携程攻略,行程,视频直播等项目的前端开发和团队管理. 一.前言 随着抖音,快手等APP的迅猛发展,短视频在移动端的地 ...

  6. 手把手教你绘制自定义地图

    1. 内容概述 自定义地图组件支持使用用户自己绘制的地图绑定和呈现数据. 我们可以直接在设计器中绘制自定义地图,只需导入底图图片,进行描边和调整标记点,即可使用. 如下图所示,我们根据一张商场的平面图 ...

  7. 【朝花夕拾】Android自定义View之(一)手把手教你看懂View绘制流程——向源码要答案

    前言 原文:Android自定义View之(一)手把手教你看懂View绘制流程--向源码要答案 View作为整个app的颜值担当,在Android体系中占有重要的地位.深入理解Android View ...

  8. C# SuperSocket 手把手教你入门 傻瓜教程---5(探索自定义AppServer、AppSession,Conmmand,用配置文件App.comfig启动服务器)

    C# SuperSocket 手把手教你入门 傻瓜教程系列教程 C# SuperSocket 手把手教你入门 傻瓜教程---1(服务器单向接收客户端发送数据) C# SuperSocket 手把手教你 ...

  9. 解放前端工程师——手把手教你开发自己的自定义列表和自定义表单系列之二接口

    前言 阅读前请按照顺序参看系列文章,效果更佳! Vue中路由到一个公共组件,然后根据路径中是否存在文件动态加载组件 解放前端工程师--手把手教你开发自己的自定义列表和自定义表单系列之一缘起 据说系列文 ...

最新文章

  1. C#.NET常见问题(FAQ)-如何修改Form不能修改窗体大小
  2. SAP Spartacus angular.json 中定义的 serve-ssr
  3. SAP Spartacus运行时错误 - The pipe cxUrl could not be found!
  4. 使用ST05 研究product extension field deletion
  5. Leetcode 219. 存在重复元素 II
  6. android屏幕基础知识
  7. 深度学习三十年创新路
  8. 有序充电matlab仿真,电动汽车有序充电策略研究
  9. HDU - 4780费用流
  10. linux dns服务无效,Linux下搭建DNS服务器及踩坑
  11. 《SQL注入攻击与防御(第2版)》百度网盘链接
  12. 易语言 html替换,易语言教程文本替换和子文本替换
  13. python的unicode编码表_Python-编码
  14. 图像去雾去雨去模糊去噪
  15. chrome-调试按钮详解
  16. 前端三件套系例之CSS——响应式布局
  17. (一)买基金的基础知识
  18. 如果阿里裁员30%是真的,你拿什么和阿里背景的程序员竞争?
  19. 《回忆之前,忘记之后---写给我记忆中的汪峰》
  20. Open3D键盘切换上下帧显示点云

热门文章

  1. 微信趣味测评小程序独立版源码
  2. 基于动态时间规整(DTW)的孤立字语音识别
  3. 值得收藏的6个OCR文字识别软件,帮你提升10倍工作效率
  4. LeetCode刷题(43)~汉明距离【异或+布赖恩·克尼根算法】
  5. 印刷常用专业术语解释
  6. 消息队列-beanstalkd
  7. beats耳机红白交替闪烁三次_beats耳机红白灯交替闪如何解决
  8. linux pdf翻译
  9. 考虑题4所示的日志记录_福建省厦门双十中学2016届高三上学期中考试地理【解析】...
  10. htc+one+m8+联通+android+5,HTC One M8e 双卡版 刷机包 Android5.0.2+Sense6 完美ROOT 下拉农历 精简稳定版...