TensorFlow 2: Thousand-Layer Neural Networks Start Here -- A ResNet Implementation

  • Overview
  • Deep Network Degradation
  • Code Implementation
    • Residual Block
    • Hyperparameters
    • ResNet 18 Network
    • Getting the Data
  • Complete Code

Overview

The deep residual network, ResNet (Deep Residual Network), is a milestone of deep learning just like AlexNet: its shortcut connections are what make networks with hundreds, even a thousand, layers trainable.

Deep Network Degradation

As network depth grows from a few layers up to around 20, results improve with depth. Beyond roughly 20 layers, however, results get worse as more layers are added: the deeper the network, the weaker the correlation between gradients becomes, and the harder the model is to optimize.

The residual network (ResNet) solves this degradation problem by adding identity (shortcut) mappings. With H(x) = F(x) + x, each block only has to fit the residual F(x) = H(x) - x instead of the full mapping; in the worst case it can learn F(x) = 0 and pass the input through unchanged, which guarantees that adding layers does not degrade the network.
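To make the idea concrete, here is a minimal sketch of a residual connection in TensorFlow (my own illustration, not part of the original implementation):

import tensorflow as tf

def residual_connection(x, f):
    """Residual connection: output = F(x) + x."""
    return f(x) + x

x = tf.random.normal([1, 8])
f = tf.keras.layers.Dense(8)  # F(x); output width must match x so the addition is valid
y = residual_connection(x, f)
print(y.shape)  # (1, 8)
# If f learned to output zeros, y would equal x exactly: the block can
# always fall back to the identity, so extra depth cannot hurt.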

Code Implementation

Residual Block

class BasicBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 convolutions plus a shortcut connection."""

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=stride, padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()  # batch normalization
        self.relu = tf.keras.layers.Activation("relu")
        self.conv2 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=1, padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()  # batch normalization

        # If stride != 1, downsample the shortcut with a 1x1 convolution
        # so its shape matches the main branch
        if stride != 1:
            self.downsample = tf.keras.Sequential([
                tf.keras.layers.Conv2D(filter_num, kernel_size=(1, 1), strides=stride)
            ])
        else:
            self.downsample = lambda x: x  # identity shortcut

    def call(self, inputs, training=None):
        # unit 1
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        # unit 2
        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)  # shortcut branch
        output = tf.keras.layers.add([out, identity])  # element-wise addition
        output = tf.nn.relu(output)
        return output
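A quick sanity check (my own snippet, assuming the BasicBlock class above is already defined): run a random batch through a block and confirm that stride 2 halves the spatial size while the 1x1 shortcut keeps the shapes compatible for the addition.

x = tf.random.normal([4, 32, 32, 64])  # dummy batch: 4 images, 32x32, 64 channels
print(BasicBlock(filter_num=64, stride=1)(x).shape)   # (4, 32, 32, 64) - identity shortcut
print(BasicBlock(filter_num=128, stride=2)(x).shape)  # (4, 16, 16, 128) - 1x1 conv shortcut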

Hyperparameters

# Hyperparameters
batch_size = 1024  # samples per training batch
learning_rate = 0.0001  # learning rate
iteration_num = 20  # number of training epochs (matches the 20-epoch log below)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # optimizer
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)  # loss (expects raw logits)
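The network's final Dense(100) layer below has no softmax, so the loss is built with from_logits=True and applies the softmax internally. A small illustration of the equivalence (my own example values):

logits = tf.constant([[2.0, 1.0, 0.1]])
labels = tf.constant([[1.0, 0.0, 0.0]])

# Passing raw logits with from_logits=True ...
a = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(labels, logits)
# ... matches applying softmax first with from_logits=False.
b = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(labels, tf.nn.softmax(logits))
print(float(a), float(b))  # both ~0.417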

ResNet 18 Network

ResNet_18 = tf.keras.Sequential([
    # stem
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1)),  # convolution
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="same"),  # pooling

    # 8 residual blocks (2 conv layers each)
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=128, stride=2),
    BasicBlock(filter_num=128, stride=1),
    BasicBlock(filter_num=256, stride=2),
    BasicBlock(filter_num=256, stride=1),
    BasicBlock(filter_num=512, stride=2),
    BasicBlock(filter_num=512, stride=1),

    tf.keras.layers.GlobalAveragePooling2D(),  # global average pooling

    # fully connected layer
    tf.keras.layers.Dense(100)  # 100 classes (raw logits)
])

# Build the model and print a summary
# (summary() prints internally and returns None, so don't wrap it in print())
ResNet_18.build(input_shape=[None, 32, 32, 3])
ResNet_18.summary()
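Before training, a throwaway forward pass (my own check, not in the original) confirms the architecture produces one logit per CIFAR-100 class:

dummy = tf.random.normal([2, 32, 32, 3])  # two fake 32x32 RGB images
print(ResNet_18(dummy).shape)  # (2, 100): a row of 100 logits per image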

Getting the Data

def pre_process(x, y):
    """Preprocess the data."""
    x = tf.cast(x, tf.float32) * 2 / 255 - 1  # scale pixels to [-1, 1]
    y = tf.one_hot(y, depth=100)  # one-hot encode the labels
    return x, y


def get_data():
    """Load and prepare the CIFAR-100 dataset."""
    # Load the data
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()
    y_train, y_test = tf.squeeze(y_train, axis=1), tf.squeeze(y_test, axis=1)  # drop the extra label axis

    # Build the train / test pipelines
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000).map(pre_process).batch(batch_size)
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).map(pre_process).batch(batch_size)

    return train_db, test_db
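To verify the pipeline (again my own snippet, not from the original), pull one batch and check the shapes and the value range produced by pre_process:

train_db, test_db = get_data()
for x, y in train_db.take(1):
    print(x.shape, y.shape)  # (1024, 32, 32, 3) (1024, 100)
    # pixels were rescaled from [0, 255] to roughly [-1, 1]
    print(float(tf.reduce_min(x)), float(tf.reduce_max(x)))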

Complete Code

Come, let's drain this cup together and look at the complete code.


Complete code:

import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 convolutions plus a shortcut connection."""

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=stride, padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()  # batch normalization
        self.relu = tf.keras.layers.Activation("relu")
        self.conv2 = tf.keras.layers.Conv2D(filter_num, kernel_size=(3, 3), strides=1, padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()  # batch normalization

        # If stride != 1, downsample the shortcut with a 1x1 convolution
        if stride != 1:
            self.downsample = tf.keras.Sequential([
                tf.keras.layers.Conv2D(filter_num, kernel_size=(1, 1), strides=stride)
            ])
        else:
            self.downsample = lambda x: x  # identity shortcut

    def call(self, inputs, training=None):
        # unit 1
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        # unit 2
        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)  # shortcut branch
        output = tf.keras.layers.add([out, identity])  # element-wise addition
        output = tf.nn.relu(output)
        return output


# Hyperparameters
batch_size = 1024  # samples per training batch
learning_rate = 0.0001  # learning rate
iteration_num = 20  # number of training epochs (matches the 20-epoch log below)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # optimizer
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)  # loss (expects raw logits)

ResNet_18 = tf.keras.Sequential([
    # stem
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1)),  # convolution
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="same"),  # pooling

    # 8 residual blocks (2 conv layers each)
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=64, stride=1),
    BasicBlock(filter_num=128, stride=2),
    BasicBlock(filter_num=128, stride=1),
    BasicBlock(filter_num=256, stride=2),
    BasicBlock(filter_num=256, stride=1),
    BasicBlock(filter_num=512, stride=2),
    BasicBlock(filter_num=512, stride=1),

    tf.keras.layers.GlobalAveragePooling2D(),  # global average pooling

    # fully connected layer
    tf.keras.layers.Dense(100)  # 100 classes (raw logits)
])

# Build the model and print a summary
ResNet_18.build(input_shape=[None, 32, 32, 3])
ResNet_18.summary()

# Compile
ResNet_18.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])


def pre_process(x, y):
    """Preprocess the data."""
    x = tf.cast(x, tf.float32) * 2 / 255 - 1  # scale pixels to [-1, 1]
    y = tf.one_hot(y, depth=100)  # one-hot encode the labels
    return x, y


def get_data():
    """Load and prepare the CIFAR-100 dataset."""
    # Load the data
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()
    y_train, y_test = tf.squeeze(y_train, axis=1), tf.squeeze(y_test, axis=1)  # drop the extra label axis

    # Build the train / test pipelines
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000).map(pre_process).batch(batch_size)
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).map(pre_process).batch(batch_size)

    return train_db, test_db


if __name__ == "__main__":
    # Get the train / test datasets
    train_db, test_db = get_data()

    # Train
    ResNet_18.fit(train_db, epochs=iteration_num, validation_data=test_db, validation_freq=1)

Output:

Model: "sequential_16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_70 (Conv2D)           (None, 30, 30, 64)        1792
_________________________________________________________________
batch_normalization_17 (Batc (None, 30, 30, 64)        256
_________________________________________________________________
activation_9 (Activation)    (None, 30, 30, 64)        0
_________________________________________________________________
max_pooling2d_26 (MaxPooling (None, 30, 30, 64)        0
_________________________________________________________________
basic_block_8 (BasicBlock)   (None, 30, 30, 64)        74368
_________________________________________________________________
basic_block_9 (BasicBlock)   (None, 30, 30, 64)        74368
_________________________________________________________________
basic_block_10 (BasicBlock)  (None, 15, 15, 128)       230784
_________________________________________________________________
basic_block_11 (BasicBlock)  (None, 15, 15, 128)       296192
_________________________________________________________________
basic_block_12 (BasicBlock)  (None, 8, 8, 256)         920320
_________________________________________________________________
basic_block_13 (BasicBlock)  (None, 8, 8, 256)         1182208
_________________________________________________________________
basic_block_14 (BasicBlock)  (None, 4, 4, 512)         3675648
_________________________________________________________________
basic_block_15 (BasicBlock)  (None, 4, 4, 512)         4723712
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0
_________________________________________________________________
dense_16 (Dense)             (None, 100)               51300
=================================================================
Total params: 11,230,948
Trainable params: 11,223,140
Non-trainable params: 7,808
_________________________________________________________________
Epoch 1/20
49/49 [==============================] - 43s 848ms/step - loss: 3.9558 - accuracy: 0.1203 - val_loss: 4.6631 - val_accuracy: 0.0100
Epoch 2/20
49/49 [==============================] - 41s 834ms/step - loss: 3.0988 - accuracy: 0.2525 - val_loss: 4.9431 - val_accuracy: 0.0112
Epoch 3/20
49/49 [==============================] - 41s 828ms/step - loss: 2.5981 - accuracy: 0.3518 - val_loss: 5.2123 - val_accuracy: 0.0150
Epoch 4/20
49/49 [==============================] - 40s 824ms/step - loss: 2.1962 - accuracy: 0.4464 - val_loss: 5.4619 - val_accuracy: 0.0230
Epoch 5/20
49/49 [==============================] - 41s 828ms/step - loss: 1.8086 - accuracy: 0.5450 - val_loss: 5.4788 - val_accuracy: 0.0313
Epoch 6/20
49/49 [==============================] - 41s 827ms/step - loss: 1.4073 - accuracy: 0.6611 - val_loss: 5.6753 - val_accuracy: 0.0386
Epoch 7/20
49/49 [==============================] - 41s 827ms/step - loss: 1.0093 - accuracy: 0.7791 - val_loss: 5.3822 - val_accuracy: 0.0671
Epoch 8/20
49/49 [==============================] - 41s 829ms/step - loss: 0.6361 - accuracy: 0.8905 - val_loss: 5.0999 - val_accuracy: 0.0961
Epoch 9/20
49/49 [==============================] - 41s 829ms/step - loss: 0.3532 - accuracy: 0.9587 - val_loss: 4.7099 - val_accuracy: 0.1366
Epoch 10/20
49/49 [==============================] - 41s 828ms/step - loss: 0.1799 - accuracy: 0.9874 - val_loss: 4.1926 - val_accuracy: 0.1899
Epoch 11/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0964 - accuracy: 0.9965 - val_loss: 3.6718 - val_accuracy: 0.2504
Epoch 12/20
49/49 [==============================] - 41s 827ms/step - loss: 0.0580 - accuracy: 0.9986 - val_loss: 3.3465 - val_accuracy: 0.2876
Epoch 13/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0390 - accuracy: 0.9995 - val_loss: 3.1585 - val_accuracy: 0.3219
Epoch 14/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0286 - accuracy: 0.9996 - val_loss: 3.1677 - val_accuracy: 0.3271
Epoch 15/20
49/49 [==============================] - 41s 830ms/step - loss: 0.0231 - accuracy: 0.9996 - val_loss: 3.1084 - val_accuracy: 0.3384
Epoch 16/20
49/49 [==============================] - 41s 829ms/step - loss: 0.0193 - accuracy: 0.9997 - val_loss: 3.1312 - val_accuracy: 0.3452
Epoch 17/20
49/49 [==============================] - 41s 829ms/step - loss: 0.0165 - accuracy: 0.9997 - val_loss: 3.1519 - val_accuracy: 0.3413
Epoch 18/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0140 - accuracy: 0.9997 - val_loss: 3.1658 - val_accuracy: 0.3435
Epoch 19/20
49/49 [==============================] - 41s 827ms/step - loss: 0.0123 - accuracy: 0.9997 - val_loss: 3.1867 - val_accuracy: 0.3433
Epoch 20/20
49/49 [==============================] - 41s 828ms/step - loss: 0.0111 - accuracy: 0.9997 - val_loss: 3.2123 - val_accuracy: 0.3447

Note: we can see that ResNet18 reaches a noticeably higher accuracy than VGG13 did. (The gap between the near-perfect training accuracy and the ~34% validation accuracy in the log also shows the model overfits heavily at this scale.)
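If you want an explicit test-set score after training, or want to keep the weights, here is a minimal follow-up sketch (the checkpoint file name is my own placeholder, not from the post):

# evaluate returns [loss, accuracy] because the model was compiled with metrics=["accuracy"]
test_loss, test_acc = ResNet_18.evaluate(test_db)
print("test loss:", test_loss, "test accuracy:", test_acc)

ResNet_18.save_weights("resnet18_cifar100.ckpt")  # hypothetical path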


Happy Father's Day to all the fathers out there!
