深度残差网络(ResNet)详解与实现(tensorflow2.x)

  • ResNet原理
  • ResNet实现
    • 模型创建
    • 数据加载
    • 模型编译
    • 模型训练
    • 测试模型
    • 训练过程

ResNet原理

深层网络在学习任务中取得了超越人眼的准确率,但是,经过实验表明,模型的性能和模型的深度并非成正比,是由于模型的表达能力过强,反而在测试数据集中性能下降。ResNet的核心是,为了防止梯度弥散或爆炸,让信息流经快捷连接到达浅层。
更正式的讲,输入xxx通过卷积层,得到特征变换后的输出F(x)F(x)F(x),与输入xxx进行对应元素的相加运算,得到最终输出H(x)H(x)H(x):
H(x)=x+F(x)H(x) = x + F(x)H(x)=x+F(x)
VGG模块和残差模块对比如下:


为了能够满足输入xxx与卷积层的输出F(x)F(x)F(x)能够相加运算,需要输入xxx的 shape 与F(x)F(x)F(x)的shape 完全一致。当出现 shape 不一致时,一般通过Conv2D进行变换,该Conv2D的核为1×1,步幅为2。

ResNet实现

使用tensorflow2.3实现ResNet

模型创建

import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import os
import math
"""
用于控制模型层数
"""
#残差块数
n = 3
depth = n * 9 + 1
def resnet_layer(inputs,num_filters=16,kernel_size=3,strides=1,activation='relu',batch_normalization=True,conv_first=True):"""2D Convolution-Batch Normalization-Activation stack builderArguments:inputs (tensor): 输入num_filters (int): 卷积核个数kernel_size (int): 卷积核大小activation (string): 激活层batch_normalization (bool): 是否使用批归一化conv_first (bool): conv-bn-active(True) or bn-active-conv (False)层堆叠次序Returns:x (tensor): 输出"""conv = keras.layers.Conv2D(num_filters,kernel_size=kernel_size,strides=strides,padding='same',kernel_initializer='he_normal',kernel_regularizer=keras.regularizers.l2(1e-4))x = inputsif conv_first:x = conv(x)if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)else:if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)x = conv(x)return xdef resnet(input_shape,depth,num_classes=10):"""ResNetArguments:input_shape (tensor): 输入尺寸depth (int): 网络层数num_classes (int): 预测类别数Return:model (Model): 模型"""if (depth - 2) % 6 != 0:raise ValueError('depth should be 6n+2')#超参数num_filters = 16num_res_blocks = int((depth - 2) / 6)inputs = keras.layers.Input(shape=input_shape)x = resnet_layer(inputs=inputs)for stack in range(3):for res_block in range(num_res_blocks):strides = 1if stack > 0 and res_block == 0:strides = 2y = resnet_layer(inputs=x,num_filters=num_filters,strides=strides)y = resnet_layer(inputs=y,num_filters=num_filters,activation=None)if stack > 0 and res_block == 0:x = resnet_layer(inputs=x,num_filters=num_filters,kernel_size=1,strides=strides,activation=None,batch_normalization=False)x = keras.layers.add([x,y])x = keras.layers.Activation('relu')(x)num_filters *= 2x = keras.layers.AveragePooling2D(pool_size=8)(x)x = keras.layers.Flatten()(x)outputs = keras.layers.Dense(num_classes,activation='softmax',kernel_initializer='he_normal')(x)model = keras.Model(inputs=inputs,outputs=outputs)return modelmodel = resnet_v1(input_shape=input_shape,depth=depth)

数据加载

#加载数据
(x_train,y_train),(x_test,y_test) = keras.datasets.cifar10.load_data()#计算类别数
num_labels = len(np.unique(y_train))#转化为one-hot编码
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)#预处理
input_shape = x_train.shape[1:]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

模型编译

#超参数
batch_size = 64
epochs = 200
#编译模型
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
model.summary()

模型训练

model.fit(x_train,y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_test,y_test),shuffle=True)

测试模型

scores = model.evaluate(x_test,y_test,batch_size=batch_size,verbose=0)
print('Test loss: ',scores[0])
print('Test accuracy: ',scores[1])

训练过程

Epoch 104/200
782/782 [==============================] - ETA: 0s - loss: 0.2250 - acc: 0.9751
Epoch 00104: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 19ms/step - loss: 0.2250 - acc: 0.9751 - val_loss: 0.4750 - val_acc: 0.9090
learning rate:  0.0001
Epoch 105/200
781/782 [============================>.] - ETA: 0s - loss: 0.2206 - acc: 0.9754
Epoch 00105: val_acc did not improve from 0.91140
782/782 [==============================] - 16s 20ms/step - loss: 0.2206 - acc: 0.9754 - val_loss: 0.4687 - val_acc: 0.9078
learning rate:  0.0001
Epoch 106/200
782/782 [==============================] - ETA: 0s - loss: 0.2160 - acc: 0.9769
Epoch 00106: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 20ms/step - loss: 0.2160 - acc: 0.9769 - val_loss: 0.4886 - val_acc: 0.9053

深度残差网络(ResNet)详解与实现(tensorflow2.x)相关推荐

  1. 残差网络resnet详解

    1 产生背景 网络的深度对于特征提取具有至关重要的作用,实验证得,如果简单的增加网络深度,会引起退化问题[Degradation问题],即准确率先上升然后达到饱和,再持续增加深度会导致准确率下降.该实 ...

  2. dlibdotnet 人脸相似度源代码_使用dlib中的深度残差网络(ResNet)实现实时人脸识别 - supersayajin - 博客园...

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 何恺明编年史之深度残差网络ResNet

    文章目录 前言 一.提出ResNet原因 二.深度残差模块 1.数学理论基础 2.深度网络结构 三.Pytorch代码实现 四.总结 前言 图像分类是计算机视觉任务的基石,在目标监测.图像分割等任务中 ...

  4. TF2.0深度学习实战(七):手撕深度残差网络ResNet

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

  5. 【深度学习之ResNet】——深度残差网络—ResNet总结

    目录 论文名称:Deep Residual Learning for Image Recognition 摘要: 1.引言 2.为什么会提出ResNet残差网络呢? 3.深度残差网络结构学习(Deep ...

  6. 深度残差网络ResNet解析

    ResNet在2015年被提出,在ImageNet比赛classification任务上获得第一名,因为它"简单与实用"并存,之后很多方法都建立在ResNet50或者ResNet1 ...

  7. 深度残差网络RESNET

    一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别的效果有着很大的影响,所以正常想法就是能把网络设计的越深越好, 但是事实上却不是这样,常规的网络的堆叠(plain netw ...

  8. 【深度学习】深度残差网络ResNet

    文章目录 1 残差网络ResNet 1.1要解决的问题 1.2 残差网络结构 1.3 捷径连接 1.4 总结 1 残差网络ResNet 1.1要解决的问题   在传统CNN架构中,如果我们简单堆叠CN ...

  9. 通过深度残差网络ResNet进行图像分类(pytorch网络多网络集成配置)

    通过深度残差网络进行图像分类(pytorch网络多网络集成配置) 简介 本项目通过配置文件修改,实现pytorch的ResNet18, ResNet34, ResNet50, ResNet101, R ...

  10. ResNet(残差网络)详解

    ResNet在<Deep Residual Learning for Image Recognition>论文中提出,是在CVPR 2016发表的一种影响深远的网络模型,由何凯明大神团队提 ...

最新文章

  1. matlab 自定义对象,面向对象:MATLAB的自定义类 [MATLAB]
  2. AdvancedEAST笔记
  3. 2015 百度之星 1004 KPI STL的妙用
  4. C语言判别输入的东东
  5. Python 科学计算库 Numpy (二) —— 索引及切片
  6. IT项目管理总结:第四章 项目综合管理
  7. PHP技巧:PATH_SEPARATOR是什么(Zend Framework引导文件中的路径用法)
  8. iPhone开发 No IB UITextField 设置圆角
  9. 计算机视觉:关于Graph cuts的简介及相关资源
  10. Atitti 图像处理 特征提取的科技树 attilax总结
  11. SQL Server 数据库的创建
  12. RouterOS利用(L2TP)实现异地组网
  13. Zigbee 协议栈网络管理
  14. 编程修养 from匠人的百宝箱
  15. web 基于jquery和canvas的打飞机小游戏
  16. 计算机作文范文,未来计算机作文范文.docx
  17. java对接中金支付接口
  18. 微软首席执行官鲍尔默简历
  19. CODEBLOCKS 17.12汉化
  20. 「业务架构」定义业务能力-备忘单

热门文章

  1. 【软工项目组】第十八次会议
  2. 注意地方hadoop中的pi值计算
  3. [转载] python复数类型-Python 复数属性和方法操作实例
  4. ALTERA 命名规则
  5. RIA and volta
  6. MYSQL:RELPACE用法
  7. sql server中除数为零的处理技巧
  8. 小故事:在缺陷中发现长处
  9. 如何在网页中弹出的模式窗口,就像C/S中的SHOWMODAL类型窗口
  10. (二)设置hexo支持mermaid