深度残差网络(ResNet)详解与实现(tensorflow2.x)
深度残差网络(ResNet)详解与实现(tensorflow2.x)
- ResNet原理
- ResNet实现
- 模型创建
- 数据加载
- 模型编译
- 模型训练
- 测试模型
- 训练过程
ResNet原理
深层网络在学习任务中取得了超越人眼的准确率,但是,经过实验表明,模型的性能和模型的深度并非成正比,是由于模型的表达能力过强,反而在测试数据集中性能下降。ResNet的核心是,为了防止梯度弥散或爆炸,让信息流经快捷连接到达浅层。
更正式的讲,输入xxx通过卷积层,得到特征变换后的输出F(x)F(x)F(x),与输入xxx进行对应元素的相加运算,得到最终输出H(x)H(x)H(x):
H(x)=x+F(x)H(x) = x + F(x)H(x)=x+F(x)
VGG模块和残差模块对比如下:
为了能够满足输入xxx与卷积层的输出F(x)F(x)F(x)能够相加运算,需要输入xxx的 shape 与F(x)F(x)F(x)的shape 完全一致。当出现 shape 不一致时,一般通过Conv2D进行变换,该Conv2D的核为1×1,步幅为2。
ResNet实现
使用tensorflow2.3实现ResNet
模型创建
import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import os
import math
"""
用于控制模型层数
"""
#残差块数
n = 3
depth = n * 9 + 1
def resnet_layer(inputs,num_filters=16,kernel_size=3,strides=1,activation='relu',batch_normalization=True,conv_first=True):"""2D Convolution-Batch Normalization-Activation stack builderArguments:inputs (tensor): 输入num_filters (int): 卷积核个数kernel_size (int): 卷积核大小activation (string): 激活层batch_normalization (bool): 是否使用批归一化conv_first (bool): conv-bn-active(True) or bn-active-conv (False)层堆叠次序Returns:x (tensor): 输出"""conv = keras.layers.Conv2D(num_filters,kernel_size=kernel_size,strides=strides,padding='same',kernel_initializer='he_normal',kernel_regularizer=keras.regularizers.l2(1e-4))x = inputsif conv_first:x = conv(x)if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)else:if batch_normalization:x = keras.layers.BatchNormalization()(x)if activation is not None:x = keras.layers.Activation(activation)(x)x = conv(x)return xdef resnet(input_shape,depth,num_classes=10):"""ResNetArguments:input_shape (tensor): 输入尺寸depth (int): 网络层数num_classes (int): 预测类别数Return:model (Model): 模型"""if (depth - 2) % 6 != 0:raise ValueError('depth should be 6n+2')#超参数num_filters = 16num_res_blocks = int((depth - 2) / 6)inputs = keras.layers.Input(shape=input_shape)x = resnet_layer(inputs=inputs)for stack in range(3):for res_block in range(num_res_blocks):strides = 1if stack > 0 and res_block == 0:strides = 2y = resnet_layer(inputs=x,num_filters=num_filters,strides=strides)y = resnet_layer(inputs=y,num_filters=num_filters,activation=None)if stack > 0 and res_block == 0:x = resnet_layer(inputs=x,num_filters=num_filters,kernel_size=1,strides=strides,activation=None,batch_normalization=False)x = keras.layers.add([x,y])x = keras.layers.Activation('relu')(x)num_filters *= 2x = keras.layers.AveragePooling2D(pool_size=8)(x)x = keras.layers.Flatten()(x)outputs = keras.layers.Dense(num_classes,activation='softmax',kernel_initializer='he_normal')(x)model = keras.Model(inputs=inputs,outputs=outputs)return modelmodel = resnet_v1(input_shape=input_shape,depth=depth)
数据加载
#加载数据
(x_train,y_train),(x_test,y_test) = keras.datasets.cifar10.load_data()#计算类别数
num_labels = len(np.unique(y_train))#转化为one-hot编码
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)#预处理
input_shape = x_train.shape[1:]
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
模型编译
#超参数
batch_size = 64
epochs = 200
#编译模型
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
model.summary()
模型训练
model.fit(x_train,y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_test,y_test),shuffle=True)
测试模型
scores = model.evaluate(x_test,y_test,batch_size=batch_size,verbose=0)
print('Test loss: ',scores[0])
print('Test accuracy: ',scores[1])
训练过程
Epoch 104/200
782/782 [==============================] - ETA: 0s - loss: 0.2250 - acc: 0.9751
Epoch 00104: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 19ms/step - loss: 0.2250 - acc: 0.9751 - val_loss: 0.4750 - val_acc: 0.9090
learning rate: 0.0001
Epoch 105/200
781/782 [============================>.] - ETA: 0s - loss: 0.2206 - acc: 0.9754
Epoch 00105: val_acc did not improve from 0.91140
782/782 [==============================] - 16s 20ms/step - loss: 0.2206 - acc: 0.9754 - val_loss: 0.4687 - val_acc: 0.9078
learning rate: 0.0001
Epoch 106/200
782/782 [==============================] - ETA: 0s - loss: 0.2160 - acc: 0.9769
Epoch 00106: val_acc did not improve from 0.91140
782/782 [==============================] - 15s 20ms/step - loss: 0.2160 - acc: 0.9769 - val_loss: 0.4886 - val_acc: 0.9053
深度残差网络(ResNet)详解与实现(tensorflow2.x)相关推荐
- 残差网络resnet详解
1 产生背景 网络的深度对于特征提取具有至关重要的作用,实验证得,如果简单的增加网络深度,会引起退化问题[Degradation问题],即准确率先上升然后达到饱和,再持续增加深度会导致准确率下降.该实 ...
- dlibdotnet 人脸相似度源代码_使用dlib中的深度残差网络(ResNet)实现实时人脸识别 - supersayajin - 博客园...
opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...
- 何恺明编年史之深度残差网络ResNet
文章目录 前言 一.提出ResNet原因 二.深度残差模块 1.数学理论基础 2.深度网络结构 三.Pytorch代码实现 四.总结 前言 图像分类是计算机视觉任务的基石,在目标监测.图像分割等任务中 ...
- TF2.0深度学习实战(七):手撕深度残差网络ResNet
写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...
- 【深度学习之ResNet】——深度残差网络—ResNet总结
目录 论文名称:Deep Residual Learning for Image Recognition 摘要: 1.引言 2.为什么会提出ResNet残差网络呢? 3.深度残差网络结构学习(Deep ...
- 深度残差网络ResNet解析
ResNet在2015年被提出,在ImageNet比赛classification任务上获得第一名,因为它"简单与实用"并存,之后很多方法都建立在ResNet50或者ResNet1 ...
- 深度残差网络RESNET
一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别的效果有着很大的影响,所以正常想法就是能把网络设计的越深越好, 但是事实上却不是这样,常规的网络的堆叠(plain netw ...
- 【深度学习】深度残差网络ResNet
文章目录 1 残差网络ResNet 1.1要解决的问题 1.2 残差网络结构 1.3 捷径连接 1.4 总结 1 残差网络ResNet 1.1要解决的问题 在传统CNN架构中,如果我们简单堆叠CN ...
- 通过深度残差网络ResNet进行图像分类(pytorch网络多网络集成配置)
通过深度残差网络进行图像分类(pytorch网络多网络集成配置) 简介 本项目通过配置文件修改,实现pytorch的ResNet18, ResNet34, ResNet50, ResNet101, R ...
- ResNet(残差网络)详解
ResNet在<Deep Residual Learning for Image Recognition>论文中提出,是在CVPR 2016发表的一种影响深远的网络模型,由何凯明大神团队提 ...
最新文章
- matlab 自定义对象,面向对象:MATLAB的自定义类 [MATLAB]
- AdvancedEAST笔记
- 2015 百度之星 1004 KPI STL的妙用
- C语言判别输入的东东
- Python 科学计算库 Numpy (二) —— 索引及切片
- IT项目管理总结:第四章 项目综合管理
- PHP技巧:PATH_SEPARATOR是什么(Zend Framework引导文件中的路径用法)
- iPhone开发 No IB UITextField 设置圆角
- 计算机视觉:关于Graph cuts的简介及相关资源
- Atitti 图像处理 特征提取的科技树 attilax总结
- SQL Server 数据库的创建
- RouterOS利用(L2TP)实现异地组网
- Zigbee 协议栈网络管理
- 编程修养 from匠人的百宝箱
- web 基于jquery和canvas的打飞机小游戏
- 计算机作文范文,未来计算机作文范文.docx
- java对接中金支付接口
- 微软首席执行官鲍尔默简历
- CODEBLOCKS 17.12汉化
- 「业务架构」定义业务能力-备忘单