前言

  ResNet(Residual Neural Network)由前微软研究院的 Kaiming He 等4名华人提出(有兴趣的可以点击这里,查看论文原文),通过使用 Residual Blocks 成功训练152层深的神经网络,在 ILSVRC 2015 比赛中获得了冠军,取得 3.57% 的 top-5 错误率,同时参数量却比 VGGNet 低,效果非常突出。ResNet 的结构可以极快地加速超深神经网络的训练,模型的准确率也有非常大的提升。上一篇博文讲解了 Inception,而 Inception V4 则是将 Inception Module 和 ResNet 相结合。可以看到 ResNet 是一个推广性非常好的网络结构,甚至可以直接应用到 Inception Net 中。


网络原理

  在ResNet提出之前,瑞士的Schmidhuber教授提出了Highway Network,原理与ResNet相似。Highway Network的目标就是为了解决极深的神经网络难以训练的问题,因为,通常看来神经网络越深其性能越好,但是网络越深其训练难度也就越大。Highway Network相当于修改了每一层的激活函数,此前的激活函数只是对输入做一个线性变换y=H(x,WH)y=H(x,WH)y=H(x,W_{H}),Highway Network则允许保留一定比例的原始输入x,即y=H(x,WH)∗T(x,WT)+x∗C(x,WC)y=H(x,WH)∗T(x,WT)+x∗C(x,WC)y=H(x,W_{H})*T(x,W_{T})+x*C(x,W_{C})。其中T为变换稀疏,C为保留系数。这样有一定比例的前一层信息不经过矩阵变换和非线性变换,直接传输到下一层,仿佛一层高速公路,因此命名为Highway Network。几百乃至上千层的Highway Network可以直接使用梯度下降算法训练,并且可以配合多种非线性激活函数,使得学习极深的神经网络也变得可行。Highway Network在理论上是可以训练任意深度的网络的,而传统的神经网络结构则对深度十分敏感。
  ResNet与Highway Network相似,也是允许前面的信息直接传递都后面的层中。而ResNet最初的灵感来自于这个问题:随着网络不断地加深,会出现一个Degradation的问题,即准确率会上升然后达到饱和,继续持续增加网络深度会导致准确率下降。如下图所示;

  试想,当一个网络已经达到了饱和的准确率,那么再在后面加上几个y=xy=xy=x的全等映射层起码不会增加错误。这里提到的全等映射的思想就是ResNet的灵感来源。假设某段网络的输入是xxx,期望输出是H(x)" role="presentation" style="position: relative;">H(x)H(x)H(x),如果直接把输入作为输出的初始结果,那么此时网络中训练的就是两者之差F(x)=H(x)−xF(x)=H(x)−xF(x)=H(x)-x。也就是说残差学习网络学习的不再是一个完整的输出,而是输出对输入做差值F(x)F(x)F(x)。下图是VGG-19、直连的34层网络、ResNet的34层结构的对比。

  下图为ResNet中两种残差学习单元的结构,其基础结构相近,两层的残差学习单元中包含了两个相同输出同单数的3x3卷积层;3层学习单元则使用了Inception Net中的1x1的卷积层,并且在中间3x3的卷积前后都使用了1x1卷积。需要注意的是,如果在卷积层中改变了Tensor的维度,需要对输入做一个线性的维度变换,保证能够衔接到后层。(在下方代码中,实现这一功能的函数为conv2d_same)

   下图是论文中提到的ResNet不同层数时的网络结构。在使用了ResNet的结构后,Degradation的现象被消除了,而且训练误差会随着层数的增大而减小,在测试集上的表现也会变好。在ResNet提出后不久,Google借鉴了ResNet的精髓,提出了Inception V4和Inception-ResNet-V2,通过融合这两个模型,在ILSVRC数据集上取得了惊人的3.08%的错误率。可见,ResNet 及其思想对卷积神经网络研究的贡献确实非常显著,具有很强的推广性。


为什么残差神经网络能有这么好的效果?

   上面提到,当网络参数快要达到稳定时,残差学习单元中的shortcut(即跳跃连接)部分的作用显现出来了。它将网络前方的输入直接传递至后方(相当于卷积层的部分对于H(x)H(x)H(x)的贡献相当小了,残差学习单元此时近似于一个恒等函数),保证了网络结构的稳定。如果,残差学习单元中的卷积层部分有学到一些有用的内容,那么残差学习单元的效果恒等函数更好,网络的性能得到了提升。
   在训练的时候,残差学习单元中的shortcut是恒等函数,它学习起来非常容易,并不会影响网络训练的效率。


Tensorflow实现网络

  根据《Tensorflow 实战》的内容,我实现了一下152层深度的ResNet。代码中有备注,可以供大家参考。

# -*- coding: utf-8 -*-
"""
Created on Tue Jul  3 09:46:34 2018@author: most_pan
"""import collections
import time
from datetime import datetime
import mathimport tensorflow as tf
slim =tf.contrib.slimclass Block(collections.namedtuple('Block',['scope','unit_fn','args'])):'A named tuple describing a ResNet block'
#unit_fn是残差学习元生成函数
#args是一个长度等于Block中单元数目的序列,序列中的每个元素
#包含第三层通道数,前两层通道数以及中间层步长(depth, depth_bottleneck, stride)三个量
#在定义一个Block类的对象时,需要提供三个信息,分别是scope,残差学习单元生成函数,以及参数列表#降采样函数
def subsample(inputs,factor,scope=None):if factor==1:return inputselse:return slim.max_pool2d(inputs,[1,1],stride=factor,scope=scope)##用于保证维度一致
def conv2d_same(inputs,num_outputs,kernel_size,stride,scope=None):if stride==1:return slim.conv2d(inputs,num_outputs,kernel_size,stride=1,padding='SAME',scope=scope)else:pad_total=kernel_size - 1pad_beg=pad_total // 2pad_end=pad_total-pad_beg
#使用tf.pad对图像进行填充inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,padding='VALID', scope=scope)@slim.add_arg_scope
def stack_blocks_dense(net,blocks,outputs_collections=None):for block in blocks:with tf.variable_scope(block.scope,'block',[net])as sc:for i,unit in enumerate(block.args):with tf.variable_scope('unit_%d'%(i+1),values=[net]):unit_depth,unit_depth_bottleneck,unit_stride=unit    #获取每个Block中的参数,包括第三层通道数,前两层通道数以及中间层步长
#unit_fn是Block类的残差神经元生成函数,它按顺序创建残差学习元并进行连接net=block.unit_fn(net,depth=unit_depth,depth_bottleneck=unit_depth_bottleneck,stride=unit_stride)net=slim.utils.collect_named_outputs(outputs_collections,sc.name,net)return netdef resnet_arg_scope(is_training=True,weight_decay=0.0001,batch_norm_decay=0.997,batch_norm_epsilon=1e-5,batch_norm_scale=True):batch_norm_params = {'decay': batch_norm_decay,'epsilon': batch_norm_epsilon,'scale': batch_norm_scale,'updates_collections': tf.GraphKeys.UPDATE_OPS,'fused': None,  # Use fused batch norm if possible.}with slim.arg_scope([slim.conv2d],weights_regularizer=slim.l2_regularizer(weight_decay),weights_initializer=slim.variance_scaling_initializer(),activation_fn=tf.nn.relu,normalizer_fn=slim.batch_norm,normalizer_params=batch_norm_params):with slim.arg_scope([slim.batch_norm], **batch_norm_params):with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:return arg_sc#利用add_arg_scope使bottleneck函数能够直接使用slim.arg_scope设置默认参数
@slim.add_arg_scope
def bottleneck(inputs,depth,depth_bottleneck,stride,outputs_collections=None,scope=None):with tf.variable_scope(scope,'bottleneck_v2',[inputs])as sc:
#获取输入的通道数目depth_in=slim.utils.last_dimension(inputs.get_shape(),min_rank=4)
#先对输入进行batch_norm,再进行非线性激活preact=slim.batch_norm(inputs,activation_fn=tf.nn.relu,scope='preact')#如果残差神经元的输出通道数目和输入的通道数目相同,那么直接对图像进行降采样,以保证shortcut尺寸和经历三个卷积层后的输出的此存相同        if depth==depth_in:shortcut=subsample(inputs,stride,'shortcut')
#如果残差神经元的输出通道数目和输入的通道数目不同,利用尺寸为1x1的卷积核对输入进行卷积,使输入通道数相同;else:shortcut=slim.conv2d(preact,depth,[1,1],stride=stride,normalizer_fn=None,activation_fn=None,scope='shortcut')
#然后,定义三个卷积层           residual=slim.conv2d(preact,depth_bottleneck,[1,1],stride=1,scope='conv1')residual=conv2d_same(residual,depth_bottleneck,3,stride,scope='conv2')residual=slim.conv2d(residual,depth,[1,1],stride=1,normalizer_fn=None,activation_fn=None,scope='conv3')#将shortcut和residual相加,作为输出        output=shortcut+residualreturn slim.utils.collect_named_outputs(outputs_collections,sc.name,output)#input是输入,blocks包含残差学习元的参数,num_classes是输出分类数,global_pool是是否进行平均池化的标志位;
def resnet_v2(inputs,blocks,num_classes=None,global_pool=True,include_root_block=True,reuse=None,scope=None):with tf.variable_scope(scope,'resnet_v2',[inputs],reuse=reuse) as sc:end_points_collection=sc.original_name_scope+'_end_points'with slim.arg_scope([slim.conv2d,bottleneck,stack_blocks_dense],outputs_collections=end_points_collection):net=inputsif include_root_block:with slim.arg_scope([slim.conv2d],activation_fn=None,normalizer_fn=None):
#卷积核为7x7步长为2的卷积层net=conv2d_same(net,64,7,stride=2,scope='conv1')
#最大值池化net=slim.max_pool2d(net,[3,3],stride=2,scope='pool1')
#调用stack_blocks_dense堆叠50个残差学习元,每个有三个卷积层net=stack_blocks_dense(net,blocks)
#先做batch norm然后使用relu激活net=slim.batch_norm(net,activation_fn=tf.nn.relu,scope='postnorm')if global_pool:     #进行平均池化net=tf.reduce_mean(net,[1,2],name='pool5',keep_dims=True)
#一个输出为num_classes的卷积层,不进行激活也不归一正则化。if num_classes is not None:net=slim.conv2d(net,num_classes,[1,1],activation_fn=None,normalizer_fn=None,scope='logits')end_points=slim.utils.convert_collection_to_dict(end_points_collection)#使用softmax进行分类         if num_classes is not None:end_points['predictions']=slim.softmax(net,scope='predictions')return net,end_points#152层残差网络
def resnet_v2_152(inputs,num_classes=None,global_pool=True,reuse=None,scope='resnet_v2_152'):blocks=[Block('block1',bottleneck,[(256,64,1)]*2+[(256,64,2)]),Block('block2',bottleneck,[(512,128,1)]*7+[(512,128,2)]),Block('block3',bottleneck,[(1024,256,1)]*35+[(1024,256,2)]),Block('block4',bottleneck,[(2048,512,1)]*3)]return resnet_v2(inputs,blocks,num_classes,global_pool,include_root_block=True,reuse=reuse,scope=scope)#测试性能定义的函数
def time_tensorflow_run(session, target, info_string):num_steps_burn_in = 10total_duration = 0.0total_duration_squared = 0.0for i in range(num_batches + num_steps_burn_in):start_time = time.time()_ = session.run(target)duration = time.time() - start_timeif i >= num_steps_burn_in:if not i % 10:print('%s: step %d, duration = %.3f' %(datetime.now(), i - num_steps_burn_in, duration))total_duration += durationtotal_duration_squared += duration * durationmn = total_duration / num_batchesvr = total_duration_squared / num_batches - mn * mnsd = math.sqrt(vr)print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %(datetime.now(), info_string, num_batches, mn, sd))batch_size,height,width=32,224,224inputs=tf.random_uniform((batch_size,height,width,3))
with slim.arg_scope(resnet_arg_scope(is_training=False)):net,end_points=resnet_v2_152(inputs,1000)init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
num_batches=100
time_tensorflow_run(sess,net,"ForWard")

代码分析

   上面代码中,resnet_v2_152函数用于设置网络中的参数,它主要是生成了一个变量blocks,并传递给resnet_v2函数。blocks是一个数组,每个元素都是一个Block类对象,每个对象有三个属性,分别是scope,残差学习单元生成函数bottleneck,以及残差学习单元的参数列表。bottleneck根据列表中的内容生成残差学习单元。其它函数的说明在代码中都有标注,大家可以自行理解。


运行结果

已完。。


参考书籍

《Tensorflow 实战》黄文坚等著;

【Tensorflow】深度学习实战06——Tensorflow实现ResNet相关推荐

  1. 深度学习实战—基于TensorFlow 2.0的人工智能开发应用

    作者:辛大奇 著 出版社:中国水利水电出版社 品牌:智博尚书 出版时间:2020-10-01 深度学习实战-基于TensorFlow 2.0的人工智能开发应用

  2. 跨年之际,中文版畅销书《TensorFlow深度学习实战大全》分享,直接送!

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 跨年之际,给大家一份福利,赠书抽奖,一共4本!感兴趣的同学可以参与一下,奖品是新书&l ...

  3. TensorFlow 深度学习实战指南中文版

    TensorFlow 深度学习实战指南中文版 第 1 章入门 安装 TensorFlow 简单的计算 逻辑回归模型构建 逻辑回归训练 第 2 章深度神经网络 基本神经网络 单隐藏层模型 单隐藏层的说明 ...

  4. Tensorflow深度学习实战之(七)--MP神经元与BP神经网络模型

    本文是在GPU版本的Tensorflow = 2.6.2 , 英伟达显卡驱动CUDA版本 =11.6,Python版本 = 3.6, 显卡为3060的环境下进行验证实验的!!! 文章目录 一.M-P神 ...

  5. 转:tensorflow深度学习实战笔记(二):把训练好的模型进行固化

    原文地址:https://blog.csdn.net/chenyuping333/article/details/82106863 目录 一.导出前向传播图 二.对模型进行固化 三.pb文件转tfli ...

  6. 【Tensorflow】深度学习实战03——Tensorflow实现AlexNet

    [fishing-pan:https://blog.csdn.net/u013921430转载请注明出处] 前言 前两篇博文中分别利用卷积神经网络识别手写数字和对CIFAR-10数据集分类,在这两次的 ...

  7. 【Tensorflow】深度学习实战01——Tensorflow实现简单的卷积网络(MNIST)

    [fishing-pan:https://blog.csdn.net/u013921430转载请注明出处] 前言 现在深度学习可以说很是热门,自己也非常感兴趣,之前有看过吴恩达老师的课程,也看过一些书 ...

  8. 【Tensorflow】深度学习实战05——Tensorflow实现Inception V3

    [fishing-pan:https://blog.csdn.net/u013921430转载请注明出处] 前言 前些日子在忙其他的事情,一直没有更新自己学习神经网络的博客,就在端午这天更吧!也祝大家 ...

  9. 【Tensorflow】深度学习实战04——Tensorflow实现VGGNet

    [fishing-pan:https://blog.csdn.net/u013921430转载请注明出处] 前言 现在已经到了Tensorflow实现卷积神经网络的第四讲了,既然是学习.实践,我一直坚 ...

最新文章

  1. Android源码开发笔记 -- Android数据库,屏幕休眠时间
  2. Sql Server系列:触发器
  3. Spring 详解(三):AOP 面向切面的编程
  4. 专 linux命令之set x详解
  5. 第四代:大规模集成电路计算机
  6. Ubuntu NFS服务器的配置
  7. Linux下搭建SVN服务器及自动更新项目文件到web发布目录(www)
  8. springboot 禁用tomcat_Spring Boot 面试的十个问题
  9. java_version干什么的_java类中serialVersionUID的作用及其使用
  10. 深入研究 Iptables 和 Netfilter 架构
  11. 17.PHPDoc 规范,PHPDocumenter 生成
  12. 火星坐标转WGS84
  13. mac虚拟机改显存_虚拟机mac怎么增大显存
  14. Android 类似360 系统启动时间提示
  15. UNITY个人版设置深色主题
  16. 看不见的竞争 文件和数据压缩
  17. 英雄联盟龙的传人皮肤爬虫
  18. 【爬虫实战】国家企业公示网-项目分析
  19. Unity API常用方法和类
  20. Vue报错:Error in v-on handler: “ReferenceError: regeneratorRuntime is not defined“

热门文章

  1. eclipse中JS文件乱码
  2. 【原创】使用yahoo雅虎js库(YUI)建立无刷新的N级树(可添加删除节点)
  3. Windows XP减肥法
  4. GitHub 标星 1.6w+,我发现了一个宝藏项目,推荐大家学习
  5. Node Version Manager--NodeJS的多版本管理工具--轻松实现多个版本的NodeJS的管理开发
  6. 磁盘分区20191017
  7. Js Vue 对象数组的创建方式
  8. servlet url-pattern配置中 / 和 /* 的区别 记录
  9. 【Python】Python库之数据分析
  10. 如何通过统计值z看置信水平_中恨他! 看看他如何通过这一简单技巧来改善统计信息页面...