经典卷积网络--ResNet残差网络
经典卷积网络--ResNet残差网络
- 1、ResNet残差网络
- 2、tf.keras实现残差结构
- 3、tensorflow2.0实现ResNet18(使用CIFAR10数据集)
借鉴点:层间残差跳连,引入前方信息,减少梯度消失,使神经网络层数变身成为可能。
1、ResNet残差网络
ResNet 即深度残差网络,由何恺明及其团队提出,是深度学习领域又一具有开创性的工作,通过对残差结构的运用,ResNet 使得训练数百层的网络成为了可能,从而具有非常强大的表征能力,其网络结构如图所示。
ResNet
的核心是残差结构,如下图所示。在残差结构中,ResNet
不再让下一层直接拟合我们想得到的底层映射,而是令其对一种残差映射进行拟合。若期望得到的底层映射为 H(x),我们令堆叠的非线性层拟合另一个映射 F(x) := H(x) – x,则原有映射变为 F(x) + x。 对这种新的残差映射进行优化时,要比优化原有的非相关映射更为容易。不妨考虑极限情况, 如果一个恒等映射是最优的,那么将残差向零逼近显然会比利用大量非线性层直接进行拟合更容易。
值得一提的是,这里的相加与
InceptionNet
中的相加是有本质区别的,Inception
中的相 加是沿深度方向叠加,像“千层蛋糕”一样,对层数进行叠加;ResNet
中的相加则是特征图 对应元素的数值相加,类似于python
语法中基本的矩阵相加。
ResNet 引入残差结构最主要的目的是解决网络层数不断加深时导致的梯度消失问题, 从之前介绍的 4 种 CNN 经典网络结构我们也可以看出,网络层数的发展趋势是不断加深的。 这是由于深度网络本身集成了低层/中层/高层特征和分类器,以多层首尾相连的方式存在, 所以可以通过增加堆叠的层数(深度)来丰富特征的层次,以取得更好的效果。
但如果只是简单地堆叠更多层数,就会导致梯度消失(爆炸)问题,它从根源上导致了函数无法收敛。然而,通过标准初始化(normalized initialization)以及中间标准化层 (intermediate normalization layer),已经可以较好地解决这个问题了,这使得深度为数十层 的网络在反向传播过程中,可以通过随机梯度下降(SGD)的方式开始收敛。
但是,当深度更深的网络也可以开始收敛时,网络退化的问题就显露了出来:随着网络 深度的增加,准确率先是达到瓶颈(这是很常见的),然后便开始迅速下降。需要注意的是, 这种退化并不是由过拟合引起的。对于一个深度比较合适的网络来说,继续增加层数反而会导致训练错误率的提升, 下图就是一个例子。
ResNet
解决的正是这个问题,其核心思路为:对一个准确率达到饱和的浅层网络,在它后面加几个恒等映射层(即 y = x,输出等于输入),增加网络深度的同时不增加误差。 这使得神经网络的层数可以超越之前的约束,提高准确率。下图展示了 ResNet
中残差结构的具体用法。
上图中的实线和虚线均表示恒等映射,实线表示通道相同,计算方式为 H(x) = F(x) + x; 虚线表示通道不同,计算方式为 H(x) = F(x) + Wx,其中 W 为卷积操作,目的是调整 x 的维 度(通道数)。
2、tf.keras实现残差结构
我们同样可以借助 tf.keras 来实现这种残差结构,定义一个新的 ResnetBlock 类。
class ResnetBlock(Model):def __init__(self, filters, strides=1, residual_path=False):super(ResnetBlock, self).__init__()self.filters = filtersself.strides = stridesself.residual_path = residual_pathself.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)self.b1 = BatchNormalization()self.a1 = Activation('relu')self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)self.b2 = BatchNormalization()# residual_path为True时,对输入进行下采样,即用1x1的卷积核做卷积操作,保证x能和F(x)维度相同,顺利相加if residual_path:self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)self.down_b1 = BatchNormalization()self.a2 = Activation('relu')def call(self, inputs):residual = inputs # residual等于输入值本身,即residual=x# 将输入通过卷积、BN层、激活层,计算F(x)x = self.c1(inputs)x = self.b1(x)x = self.a1(x)x = self.c2(x)y = self.b2(x)if self.residual_path:residual = self.down_c1(inputs)residual = self.down_b1(residual)out = self.a2(y + residual) # 最后输出的是两部分的和,即F(x)+x或F(x)+Wx,再过激活函数return out
卷积操作仍然采用典型的 C、B、A 结构,激活采用 Relu
函数;为了保证 F(x)和 x 可以顺利相加,二者的维度必须相同,这里利用的是 1 * 1 卷积来实现
利用这种结构,就可以利用 tf.keras 来构建出 ResNet 模型,如下图所示。
对应代码:
class ResNet18(Model):def __init__(self, block_list, initial_filters=64): # block_list表示每个block有几个卷积层super(ResNet18, self).__init__()self.num_blocks = len(block_list) # 共有几个blockself.block_list = block_listself.out_filters = initial_filtersself.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)self.b1 = BatchNormalization()self.a1 = Activation('relu')self.blocks = tf.keras.models.Sequential()# 构建ResNet网络结构for block_id in range(len(block_list)): # 第几个resnet blockfor layer_id in range(block_list[block_id]): # 第几个卷积层if block_id != 0 and layer_id == 0: # 对除第一个block以外的每个block的输入进行下采样block = ResnetBlock(self.out_filters, strides=2, residual_path=True)else:block = ResnetBlock(self.out_filters, residual_path=False)self.blocks.add(block) # 将构建好的block加入resnetself.out_filters *= 2 # 下一个block的卷积核数是上一个block的2倍self.p1 = tf.keras.layers.GlobalAveragePooling2D()self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, inputs):x = self.c1(inputs)x = self.b1(x)x = self.a1(x)x = self.blocks(x)x = self.p1(x)y = self.f1(x)return ymodel = ResNet18([2, 2, 2, 2])
参数 block_list
表示 ResNet
中 block
的数量;initial_filters
表示初始的卷积核数量。可以看到该模型同样使用了全局平均池化的方式来替代全连接层。
对于 ResNet
的残差单元来说,除了这里采用的两层结构外,还有一种三层结构,如下图所示。
两层残差单元多用于层数较少的网络,三层残差单元多用于层数较多的网络,以减少计算的参数量。
总体上看,ResNet
取得的成果还是相当巨大的,它将网络深度提升到了 152 层,于 2015 年将 ImageNet 图像识别 Top5 错误率降至 3.57 %。
3、tensorflow2.0实现ResNet18(使用CIFAR10数据集)
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Modelnp.set_printoptions(threshold=np.inf)cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class ResnetBlock(Model):def __init__(self, filters, strides=1, residual_path=False):super(ResnetBlock, self).__init__()self.filters = filtersself.strides = stridesself.residual_path = residual_pathself.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)self.b1 = BatchNormalization()self.a1 = Activation('relu')self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)self.b2 = BatchNormalization()# residual_path为True时,对输入进行下采样,即用1x1的卷积核做卷积操作,保证x能和F(x)维度相同,顺利相加if residual_path:self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)self.down_b1 = BatchNormalization()self.a2 = Activation('relu')def call(self, inputs):residual = inputs # residual等于输入值本身,即residual=x# 将输入通过卷积、BN层、激活层,计算F(x)x = self.c1(inputs)x = self.b1(x)x = self.a1(x)x = self.c2(x)y = self.b2(x)if self.residual_path:residual = self.down_c1(inputs)residual = self.down_b1(residual)out = self.a2(y + residual) # 最后输出的是两部分的和,即F(x)+x或F(x)+Wx,再过激活函数return outclass ResNet18(Model):def __init__(self, block_list, initial_filters=64): # block_list表示每个block有几个卷积层super(ResNet18, self).__init__()self.num_blocks = len(block_list) # 共有几个blockself.block_list = block_listself.out_filters = initial_filtersself.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)self.b1 = BatchNormalization()self.a1 = Activation('relu')self.blocks = tf.keras.models.Sequential()# 构建ResNet网络结构for block_id in range(len(block_list)): # 第几个resnet blockfor layer_id in range(block_list[block_id]): # 第几个卷积层if block_id != 0 and layer_id == 0: # 对除第一个block以外的每个block的输入进行下采样block = ResnetBlock(self.out_filters, strides=2, residual_path=True)else:block = ResnetBlock(self.out_filters, residual_path=False)self.blocks.add(block) # 将构建好的block加入resnetself.out_filters *= 2 # 下一个block的卷积核数是上一个block的2倍self.p1 = tf.keras.layers.GlobalAveragePooling2D()self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, inputs):x = self.c1(inputs)x = self.b1(x)x = self.a1(x)x = self.blocks(x)x = self.p1(x)y = self.f1(x)return ymodel = ResNet18([2, 2, 2, 2])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/ResNet18.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
# history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
# callbacks=[cp_callback])
model.summary()# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()############################################### show ################################################ 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
acc和loss曲线:
模型摘要:
经典卷积网络--ResNet残差网络相关推荐
- CNN经典模型:深度残差网络(DRN)ResNet
一说起"深度学习",自然就联想到它非常显著的特点"深.深.深"(重要的事说三遍),通过很深层次的网络实现准确率非常高的图像识别.语音识别等能力.因此,我们自然很 ...
- 图像分类经典卷积神经网络—ResNet论文翻译(中英文对照版)—Deep Residual Learning for Image Recognition(深度残差学习的图像识别)
图像分类经典论文翻译汇总:[翻译汇总] 翻译pdf文件下载:[下载地址] 此版为中英文对照版,纯中文版请稳步:[ResNet纯中文版] Deep Residual Learning for Image ...
- Resnet 残差网络使用案例
Resnet 网络 深度残差网络(Deep residual network, ResNet)的提出是CNN图像史上的一件里程碑事件,在各类数据集上都有不凡的表现,Resnet是残差网络(Residu ...
- 图像分类经典卷积神经网络—ResNet论文翻译(纯中文版)—Deep Residual Learning for Image Recognition(深度残差学习的图像识别)
图像分类经典论文翻译汇总:[翻译汇总] 翻译pdf文件下载:[下载地址] 此版为纯中文版,中英文对照版请稳步:[ResNet中英文对照版] Deep Residual Learning for Ima ...
- ResNet残差网络及变体详解(符代码实现)
本文通过分析深度网络模型的缺点引出ResNet残差网络,并介绍了几种变体,最后用代码实现ResNet18. 文章目录 前言 模型退化 残差结构 ResNet网络结构 Pre Activation Re ...
- 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读
ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...
- 【五一创作】使用Resnet残差网络对图像进行分类(猫十二分类,模型定义、训练、保存、预测)(一)
使用Resnet残差网络对图像进行分类 (猫十二分类,模型定义.训练.保存.预测)(一) 目录 一.项目简介 二.环境说明 1.安装库 2.导入需要的库 三.分类过程 (1).解压数据集 (2).相关 ...
- (pytorch-深度学习系列)ResNet残差网络的理解-学习笔记
ResNet残差网络的理解 ResNet伴随文章 Deep Residual Learning for Image Recognition 诞生,该文章是MSRA何凯明团队在2015年ImageNet ...
- ResNet 残差网络、残差块
在深度学习中,为了增强模型的学习能力,网络的层数会不断的加深,于此同时,也伴随着一些比较棘手的问题,主要包括: ①模型复杂度上升,网络训练困难 ②出现梯度消失/梯度爆炸问题 ③网络退化,即增加层数并不 ...
- ResNet残差网络Pytorch实现——对花的种类进行训练
ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...
最新文章
- BaseTDI.sys 瑞星卡巴冲突,导致机器蓝屏
- UVa10000_Longest Paths(最短路SPFA)
- framebuffer
- .net MySql
- 沈阳药科大学计算机二级好考吗,沈阳药科大学考研难吗?一般要什么水平才可以进入?...
- 计算机设置重启时间表,电脑定时开关和重启方法
- kubernetes视频教程笔记 (15)-RC、RS和Deployment的关联
- 通过JDBC连接Oracle数据库中的十大技巧
- python怎么排名次_2019:python第3次获得TIOBE最佳年度语言排名
- SSM+Dubbox电商项目 - 品优购mall
- 数值积分方法的总结(从简单梯形积分到龙贝格积分、自适应积分、高斯积分等)
- 可以在一眨眼之间接触到庞大受众的方式
- 如何把PPT连背景一起复制?
- 2021年茶艺师(初级)考试题及茶艺师(初级)新版试题
- 3月下旬到5月中旬之前采制的茶叶
- 鸿蒙系统3.0多大内存,鸿蒙2.0终于上机实测 多大内存能跑?
- 在华为云 CCE 上部署 EMQX MQTT 服务器集群
- 电脑增加内存修改注册表,让你的电脑快到停不下来
- mongodb中的in和notin的查询
- JAVA毕设项目个性化推荐的扬州农业文化旅游管理平台(java+VUE+Mybatis+Maven+Mysql)
热门文章
- POS机全国产化电子元件推荐方案
- 高中数学知识点总结:函数零点经典例题解题技巧与方法总结
- 姓名生成---拼音简码(大小写)---拼音全码(大小写)
- 小熊派BearPi-IoT(GD)之IoT Studio开发环境搭建
- java毕业答辩_Java毕业设计答辩技巧
- how to use 1checker_vim command
- 【原创】JS文件替换神器--Chrome ReRes插件
- WIN10环境下VS2003的安装
- js实现双人对战五子棋
- 离线强化学习总结!(原理、数据集、算法、复杂性分析、超参数调优等)