TensorFlow2 千层神经网络, 始步于此 --ResNet 实现
TensorFlow2 千层神经网络, 始步于此 --ResNet 实现
- 概述
- 深度网络退化
- 代码实现
- 残差块
- 超参数
- ResNet 18 网络
- 获取数据
- 完整代码
概述
深度残差网络 ResNet (Deep Residual Network) 和 AlexNet 一样是深度学习的一个里程碑.
深度网络退化
当网络深度从 0 增加到 20 的时候, 结果会随着网络的深度而变好. 但当网络超过 20 层的时候, 结果会随着网络深度的增加而下降. 网络的层数越深, 梯度之间的相关性会越来越差, 模型也更难优化.
残差网络 (ResNet) 通过增加恒等映射 (Identity Mapping) 来解决网络退化问题: H(x) = F(x) + x
通过学习残差 F(x) 而不是直接拟合恒等映射, 保证了网络不会退化.
代码实现
残差块
class BasicBlock(tf.keras.layers.Layer):
    """Two-convolution residual block: output = relu(F(x) + shortcut(x)).

    Args:
        filter_num: number of filters used by both 3x3 convolutions.
        stride: stride of the first convolution. When stride != 1 the
            main path downsamples, so the shortcut branch is downsampled
            with a strided 1x1 convolution to keep shapes compatible at
            the addition.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        # Main path: conv -> BN -> relu -> conv -> BN.
        self.conv1 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=stride, padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()  # normalization
        self.relu = tf.keras.layers.Activation("relu")
        self.conv2 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=1, padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()  # normalization
        # Shortcut path: if stride != 1, downsample with a 1x1 convolution.
        # NOTE: Sequential is documented to take a *list* of layers; the
        # original passed a bare layer, which is fragile across versions.
        if stride != 1:
            self.downsample = tf.keras.Sequential(
                [tf.keras.layers.Conv2D(filter_num, kernel_size=(1, 1), strides=stride)]
            )
        else:
            self.downsample = lambda x: x  # identity shortcut

    def call(self, inputs, training=None):
        # Unit 1: conv -> BN -> relu.
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        # Unit 2: conv -> BN (no activation before the residual add).
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        # Residual connection, then the final activation.
        identity = self.downsample(inputs)  # downsample the shortcut if needed
        output = tf.keras.layers.add([out, identity])
        output = tf.nn.relu(output)
        return output
超参数
# Hyper-parameters for training.
batch_size = 1024 # samples per training batch
learning_rate = 0.0001 # Adam learning rate
iteration_num = 5 # number of training epochs
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) # optimizer
loss = tf.losses.CategoricalCrossentropy(from_logits=True) # loss (expects raw logits)
ResNet 18 网络
# ResNet-18: a conv/BN/relu/pool stem, 8 BasicBlocks (2 conv layers each),
# global average pooling and a 100-way classification head.
ResNet_18 = tf.keras.Sequential([
    # Stem.
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1)),  # convolution
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="same"),  # pooling
    # 8 residual blocks (two layers each).
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=128, stride=2),
    BasicBlock(filter_num=128, stride=1),
    BasicBlock(filter_num=256, stride=2),
    BasicBlock(filter_num=256, stride=1),
    BasicBlock(filter_num=512, stride=2),
    BasicBlock(filter_num=512, stride=1),
    tf.keras.layers.GlobalAveragePooling2D(),  # pooling
    # Fully-connected head.
    tf.keras.layers.Dense(100),  # 100 classes
])

# Build with the CIFAR-100 input shape and print the layer summary.
ResNet_18.build(input_shape=[None, 32, 32, 3])
print(ResNet_18.summary())
获取数据
def pre_process(x, y):
    """Scale images from [0, 255] to [-1, 1] and one-hot labels to 100 classes."""
    x = tf.cast(x, tf.float32) * 2 / 255 - 1  # range -1..1
    y = tf.one_hot(y, depth=100)
    return x, y


def get_data():
    """Load CIFAR-100 and return (train_db, test_db) tf.data pipelines."""
    # Load the raw arrays.
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()
    # Labels arrive as (N, 1); squeeze to (N,) so one_hot yields (N, 100).
    y_train, y_test = tf.squeeze(y_train, axis=1), tf.squeeze(y_test, axis=1)
    # Build the batched pipelines (shuffle only the training set).
    train_db = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
                .shuffle(10000)
                .map(pre_process)
                .batch(batch_size))
    test_db = (tf.data.Dataset.from_tensor_slices((X_test, y_test))
               .map(pre_process)
               .batch(batch_size))
    return train_db, test_db
完整代码
来, 让我们干了这杯酒, 来看完整代码.
完整代码:
import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Two-convolution residual block: output = relu(F(x) + shortcut(x)).

    Args:
        filter_num: number of filters used by both 3x3 convolutions.
        stride: stride of the first convolution. When stride != 1 the
            shortcut branch is downsampled with a strided 1x1 convolution
            so shapes match at the addition.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        # Main path: conv -> BN -> relu -> conv -> BN.
        self.conv1 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=stride, padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()  # normalization
        self.relu = tf.keras.layers.Activation("relu")
        self.conv2 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=1, padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()  # normalization
        # Shortcut path: if stride != 1, downsample with a 1x1 convolution.
        # NOTE: Sequential is documented to take a *list* of layers.
        if stride != 1:
            self.downsample = tf.keras.Sequential(
                [tf.keras.layers.Conv2D(filter_num, kernel_size=(1, 1), strides=stride)]
            )
        else:
            self.downsample = lambda x: x  # identity shortcut

    def call(self, inputs, training=None):
        # Unit 1: conv -> BN -> relu.
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        # Unit 2: conv -> BN (no activation before the residual add).
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        # Residual connection, then the final activation.
        identity = self.downsample(inputs)  # downsample the shortcut if needed
        output = tf.keras.layers.add([out, identity])
        output = tf.nn.relu(output)
        return output


# Hyper-parameters.
batch_size = 1024  # samples per training batch
learning_rate = 0.0001  # Adam learning rate
iteration_num = 5  # number of training epochs
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # optimizer
loss = tf.losses.CategoricalCrossentropy(from_logits=True)  # loss (expects raw logits)

# ResNet-18: a conv/BN/relu/pool stem, 8 BasicBlocks (2 conv layers each),
# global average pooling and a 100-way classification head.
ResNet_18 = tf.keras.Sequential([
    # Stem.
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1)),  # convolution
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="same"),  # pooling
    # 8 residual blocks (two layers each).
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=128, stride=2),
    BasicBlock(filter_num=128, stride=1),
    BasicBlock(filter_num=256, stride=2),
    BasicBlock(filter_num=256, stride=1),
    BasicBlock(filter_num=512, stride=2),
    BasicBlock(filter_num=512, stride=1),
    tf.keras.layers.GlobalAveragePooling2D(),  # pooling
    # Fully-connected head.
    tf.keras.layers.Dense(100),  # 100 classes
])

# Build with the CIFAR-100 input shape and print the layer summary.
ResNet_18.build(input_shape=[None, 32, 32, 3])
print(ResNet_18.summary())

# Compile: optimizer + loss + tracked metric.
ResNet_18.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def pre_process(x, y):
    """Scale images from [0, 255] to [-1, 1] and one-hot labels to 100 classes."""
    x = tf.cast(x, tf.float32) * 2 / 255 - 1  # range -1..1
    y = tf.one_hot(y, depth=100)
    return x, y


def get_data():
    """Load CIFAR-100 and return (train_db, test_db) tf.data pipelines."""
    # Load the raw arrays.
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()
    # Labels arrive as (N, 1); squeeze to (N,) so one_hot yields (N, 100).
    y_train, y_test = tf.squeeze(y_train, axis=1), tf.squeeze(y_test, axis=1)
    # Build the batched pipelines (shuffle only the training set).
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000).map(pre_process).batch(batch_size)
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).map(pre_process).batch(batch_size)
    return train_db, test_db


if __name__ == "__main__":
    # Fetch the split datasets.
    train_db, test_db = get_data()
    # Train, validating on the test split after every epoch.
    ResNet_18.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)
输出结果:
Model: "sequential_16"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_70 (Conv2D) (None, 30, 30, 64) 1792
_________________________________________________________________
batch_normalization_17 (Batc (None, 30, 30, 64) 256
_________________________________________________________________
activation_9 (Activation) (None, 30, 30, 64) 0
_________________________________________________________________
max_pooling2d_26 (MaxPooling (None, 30, 30, 64) 0
_________________________________________________________________
basic_block_8 (BasicBlock) (None, 30, 30, 64) 74368
_________________________________________________________________
basic_block_9 (BasicBlock) (None, 30, 30, 64) 74368
_________________________________________________________________
basic_block_10 (BasicBlock) (None, 15, 15, 128) 230784
_________________________________________________________________
basic_block_11 (BasicBlock) (None, 15, 15, 128) 296192
_________________________________________________________________
basic_block_12 (BasicBlock) (None, 8, 8, 256) 920320
_________________________________________________________________
basic_block_13 (BasicBlock) (None, 8, 8, 256) 1182208
_________________________________________________________________
basic_block_14 (BasicBlock) (None, 4, 4, 512) 3675648
_________________________________________________________________
basic_block_15 (BasicBlock) (None, 4, 4, 512) 4723712
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512) 0
_________________________________________________________________
dense_16 (Dense) (None, 100) 51300
=================================================================
Total params: 11,230,948
Trainable params: 11,223,140
Non-trainable params: 7,808
_________________________________________________________________
None
Epoch 1/20
49/49 [==============================] - 43s 848ms/step - loss: 3.9558 - accuracy: 0.1203 - val_loss: 4.6631 - val_accuracy: 0.0100
Epoch 2/20
49/49 [==============================] - 41s 834ms/step - loss: 3.0988 - accuracy: 0.2525 - val_loss: 4.9431 - val_accuracy: 0.0112
Epoch 3/20
49/49 [==============================] - 41s 828ms/step - loss: 2.5981 - accuracy: 0.3518 - val_loss: 5.2123 - val_accuracy: 0.0150
Epoch 4/20
49/49 [==============================] - 40s 824ms/step - loss: 2.1962 - accuracy: 0.4464 - val_loss: 5.4619 - val_accuracy: 0.0230
Epoch 5/20
49/49 [==============================] - 41s 828ms/step - loss: 1.8086 - accuracy: 0.5450 - val_loss: 5.4788 - val_accuracy: 0.0313
Epoch 6/20
49/49 [==============================] - 41s 827ms/step - loss: 1.4073 - accuracy: 0.6611 - val_loss: 5.6753 - val_accuracy: 0.0386
Epoch 7/20
49/49 [==============================] - 41s 827ms/step - loss: 1.0093 - accuracy: 0.7791 - val_loss: 5.3822 - val_accuracy: 0.0671
Epoch 8/20
49/49 [==============================] - 41s 829ms/step - loss: 0.6361 - accuracy: 0.8905 - val_loss: 5.0999 - val_accuracy: 0.0961
Epoch 9/20
49/49 [==============================] - 41s 829ms/step - loss: 0.3532 - accuracy: 0.9587 - val_loss: 4.7099 - val_accuracy: 0.1366
Epoch 10/20
49/49 [==============================] - 41s 828ms/step - loss: 0.1799 - accuracy: 0.9874 - val_loss: 4.1926 - val_accuracy: 0.1899
Epoch 11/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0964 - accuracy: 0.9965 - val_loss: 3.6718 - val_accuracy: 0.2504
Epoch 12/20
49/49 [==============================] - 41s 827ms/step - loss: 0.0580 - accuracy: 0.9986 - val_loss: 3.3465 - val_accuracy: 0.2876
Epoch 13/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0390 - accuracy: 0.9995 - val_loss: 3.1585 - val_accuracy: 0.3219
Epoch 14/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0286 - accuracy: 0.9996 - val_loss: 3.1677 - val_accuracy: 0.3271
Epoch 15/20
49/49 [==============================] - 41s 830ms/step - loss: 0.0231 - accuracy: 0.9996 - val_loss: 3.1084 - val_accuracy: 0.3384
Epoch 16/20
49/49 [==============================] - 41s 829ms/step - loss: 0.0193 - accuracy: 0.9997 - val_loss: 3.1312 - val_accuracy: 0.3452
Epoch 17/20
49/49 [==============================] - 41s 829ms/step - loss: 0.0165 - accuracy: 0.9997 - val_loss: 3.1519 - val_accuracy: 0.3413
Epoch 18/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0140 - accuracy: 0.9997 - val_loss: 3.1658 - val_accuracy: 0.3435
Epoch 19/20
49/49 [==============================] - 41s 827ms/step - loss: 0.0123 - accuracy: 0.9997 - val_loss: 3.1867 - val_accuracy: 0.3433
Epoch 20/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0111 - accuracy: 0.9997 - val_loss: 3.2123 - val_accuracy: 0.3447
注: 我们可以看出 ResNet18 比 VGG13 的准确率高了一大截.
祝全天下的父亲节日快乐!
TensorFlow2 千层神经网络, 始步于此 --ResNet 实现相关推荐
- 《预训练周刊》第40期: 量子预训练、千层BERT与GPT
No.40 智源社区 预训练组 预 训 练 研究 观点 资源 活动 周刊订阅 告诉大家一个好消息,<预训练周刊>已经开启"订阅功能",以后我们会向您自动推送最新版的&l ...
- DeepMind激起千层浪的这篇论文,并非无所不能
皇甫琦 葛冬冬 撰稿 金磊 整理自 凹非寺 量子位 报道 | 公众号 QbitAI 本文对DeepMind近期的神经网络求解MIP(混合整数规划)的论文进行了一些初步解读.事实上,相较于此领域近期的类 ...
- DeepLearning.AI第一部分第三周、 浅层神经网络(Shallow neural networks)
文章目录 3.1 一些简单的介绍 3.2神经网络的表示Neural Network Representation 3.3计算一个神经网络的输出Computing a Neural Network's ...
- 手写 单隐藏层神经网络_反向传播(Matlab实现)
文章目录 要点 待优化 效果 代码 mian train_neural_net 待优化(1)已完成 要点 1.sigmoid函数做为激活函数,二分类交叉熵函数做损失函数 2.可以同时对整个训练集进行训 ...
- 千层套路 - Vue 3.0 初始化源码探秘
关注若川视野, 回复"pdf" 领取资料,回复"1",可加群长期交流学习 刘崇桢,微医云服务团队前端工程师,左手抱娃.右手持家的非典型码农. 9 月初 Vue. ...
- 1.3)深度学习笔记------浅层神经网络
目录 1)Neural Network Overview 2)Neural Network Representation 3)Computing a Neural Network's Output(重 ...
- 深度学习笔记(4) 浅层神经网络
深度学习笔记(4) 浅层神经网络 1. 神经网络概述 2. 激活函数 3. 激活函数的导数 4. 神经网络的梯度下降 5. 随机初始化 1. 神经网络概述 神经网络看起来是如下: 有输入特征x1.x2 ...
- 神经网络与深度学习三:编写单隐层神经网络
三:编写单隐层神经网络 1 神经网络概述 这篇文章你会学到如何实现一个神经网络,在我们深入学习技术细节之前,现在先大概快速的了解一下如何实现神经网络,如果你对某些内容不甚理解(后面的文章中会深入其中的 ...
- 实验二 单隐层神经网络
一.实验目的 (1)学习并掌握常见的机器学习方法: (2)能够结合所学的python知识实现机器学习算法: (3)能够用所学的机器学习算法解决实际问题. 二.实验内容与要求 (1)掌握神经网络的基本原 ...
最新文章
- 黄聪:IE6下用控制图片最大显示尺寸
- 小波包分解 matlab_多尺度一维小波分解
- FPGA之道(10)布线资源与接口资源
- 勿谈大,且看Bloomberg的中数据处理平台
- wed6699整站程序下载【首发】
- arcgis10.1连接sqlserver数据库常见问题(转载)
- php mvc登陆注册,Asp.Net MVC 5使用Identity之简单的注册和登陆
- 数论 + 公式 - HDU 4335 What is N?
- Linux操作Oracle(17)——linux oracle启动时 :各种报错 解决方案(2020.07.30更新...)
- 任务02——安装 Intellj IDEA,编写一个简易四则运算小程序,并将代码提交到 GitHub...
- oracle查询语句中case when的使用
- u盘ios刻录_用UltraISO刻录U盘安装系统
- 怎么将几张pdf合并成一张_怎么把多个PDF合并成一个PDF?分享合并PDF文件最简单的方法...
- 传智教育|如何转行互联网高薪岗位之一的软件测试?(附软件测试学习路线图)
- 植树问题java,云南省优秀多媒体育软件大赛公示.doc
- visio绘制网络拓扑图要求_必备!可以电脑在线使用的3款网络拓扑图软件安利
- 作为一名基层管理者如何利用情商管理自己和团队(二)
- android 覆盖虚拟按键,解决Android 虚拟按键遮住了页面内容的问题
- 豆瓣首页话题输入框的实现
- spring boot check/token Principal 如何注入
热门文章
- 华为Matebook14 预装office 重新安装
- oracle数据库ORA-00918: 未明确定义列
- 医学信息学相关术语、缩语及专业名词
- web H5 首页引导图
- android项目 之 来电管家(3) ----- 添加与删除黑名单
- 宅霸我的世界服务器无响应,宅霸我的世界 v4.5.0 官方版
- php一句话上传木马,一句话木马上传常见的几种方法
- unity开发炉石传说系列玩家手中卡牌出入及移动排列代码
- 3年测试经验,用例设计竟然不知道状态迁移法?
- Photoshop cc 2015 导出web所用格式提示错误 无法完成该操作怎么办?用这个方法轻松解决!