从ResNet101到ResNet50
一直用VGG训练,几天前想看下ResNet的效果如何,因为SSD源码中有python实现的ResNet网络结构实现代码,包含ResNet101和ResNet152,直接拿ResNet101来训练,GTX1060配置,batchsize竟然只降到2才跑的起来,果然一直收敛不了。看了下model_libs.py里面的实现代码:
def ResNet101Body(net, from_layer, use_pool5=True, use_dilation_conv5=False, **bn_param):conv_prefix = ''conv_postfix = ''bn_prefix = 'bn_'bn_postfix = ''scale_prefix = 'scale_'scale_postfix = ''ConvBNLayer(net, from_layer, 'conv1', use_bn=True, use_relu=True,num_output=64, kernel_size=7, pad=3, stride=2,conv_prefix=conv_prefix, conv_postfix=conv_postfix,bn_prefix=bn_prefix, bn_postfix=bn_postfix,scale_prefix=scale_prefix, scale_postfix=scale_postfix, **bn_param)net.pool1 = L.Pooling(net.conv1, pool=P.Pooling.MAX, kernel_size=3, stride=2)ResBody(net, 'pool1', '2a', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=True, **bn_param)ResBody(net, 'res2a', '2b', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=False, **bn_param)ResBody(net, 'res2b', '2c', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=False, **bn_param)ResBody(net, 'res2c', '3a', out2a=128, out2b=128, out2c=512, stride=2, use_branch1=True, **bn_param)from_layer = 'res3a'for i in xrange(1, 4):block_name = '3b{}'.format(i)ResBody(net, from_layer, block_name, out2a=128, out2b=128, out2c=512, stride=1, use_branch1=False, **bn_param)from_layer = 'res{}'.format(block_name)ResBody(net, from_layer, '4a', out2a=256, out2b=256, out2c=1024, stride=2, use_branch1=True, **bn_param)from_layer = 'res4a'for i in xrange(1, 23):block_name = '4b{}'.format(i)ResBody(net, from_layer, block_name, out2a=256, out2b=256, out2c=1024, stride=1, use_branch1=False, **bn_param)from_layer = 'res{}'.format(block_name)stride = 2dilation = 1if use_dilation_conv5:stride = 1dilation = 2ResBody(net, from_layer, '5a', out2a=512, out2b=512, out2c=2048, stride=stride, use_branch1=True, dilation=dilation, **bn_param)ResBody(net, 'res5a', '5b', out2a=512, out2b=512, out2c=2048, stride=1, use_branch1=False, dilation=dilation, **bn_param)ResBody(net, 'res5b', '5c', out2a=512, out2b=512, out2c=2048, stride=1, use_branch1=False, dilation=dilation, **bn_param)if use_pool5:net.pool5 = L.Pooling(net.res5c, pool=P.Pooling.AVE, global_pooling=True)return net
RenNet152Body为:
def ResNet152Body(net, from_layer, use_pool5=True, use_dilation_conv5=False, **bn_param):conv_prefix = ''conv_postfix = ''bn_prefix = 'bn_'bn_postfix = ''scale_prefix = 'scale_'scale_postfix = ''ConvBNLayer(net, from_layer, 'conv1', use_bn=True, use_relu=True,num_output=64, kernel_size=7, pad=3, stride=2,conv_prefix=conv_prefix, conv_postfix=conv_postfix,bn_prefix=bn_prefix, bn_postfix=bn_postfix,scale_prefix=scale_prefix, scale_postfix=scale_postfix, **bn_param)net.pool1 = L.Pooling(net.conv1, pool=P.Pooling.MAX, kernel_size=3, stride=2)ResBody(net, 'pool1', '2a', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=True, **bn_param)ResBody(net, 'res2a', '2b', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=False, **bn_param)ResBody(net, 'res2b', '2c', out2a=64, out2b=64, out2c=256, stride=1, use_branch1=False, **bn_param)ResBody(net, 'res2c', '3a', out2a=128, out2b=128, out2c=512, stride=2, use_branch1=True, **bn_param)from_layer = 'res3a'for i in xrange(1, 8):block_name = '3b{}'.format(i)ResBody(net, from_layer, block_name, out2a=128, out2b=128, out2c=512, stride=1, use_branch1=False, **bn_param)from_layer = 'res{}'.format(block_name)ResBody(net, from_layer, '4a', out2a=256, out2b=256, out2c=1024, stride=2, use_branch1=True, **bn_param)from_layer = 'res4a'for i in xrange(1, 36):block_name = '4b{}'.format(i)ResBody(net, from_layer, block_name, out2a=256, out2b=256, out2c=1024, stride=1, use_branch1=False, **bn_param)from_layer = 'res{}'.format(block_name)stride = 2dilation = 1if use_dilation_conv5:stride = 1dilation = 2ResBody(net, from_layer, '5a', out2a=512, out2b=512, out2c=2048, stride=stride, use_branch1=True, dilation=dilation, **bn_param)ResBody(net, 'res5a', '5b', out2a=512, out2b=512, out2c=2048, stride=1, use_branch1=False, dilation=dilation, **bn_param)ResBody(net, 'res5b', '5c', out2a=512, out2b=512, out2c=2048, stride=1, use_branch1=False, dilation=dilation, **bn_param)if use_pool5:net.pool5 = L.Pooling(net.res5c, pool=P.Pooling.AVE, global_pooling=True)return net
其中每次调用ResBody,当use_brabch1 = True,会创建4个卷积层,当use_brabch1 = False时,创建3个卷积层。ResNet101Body和ResNet152Body的区别在于两个for循环的次数不一样,101层和152层差的51层就是这里体现的,所以现在要创建ResNet50Body就容易多了。根据网上下载的模型对应的ResNet_50_train_val.prototxt,对上面代码进行修改即可。50层,batchsize=4,训练马上收敛。当然训练方式多种,可用直接利用已有ResNet_50_train_val.prototxt进行训练。
从ResNet101到ResNet50相关推荐
- X射线图像中的目标检测
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 1 动机和背景 每天有数百万人乘坐地铁.民航飞机等公共交通工具,因 ...
- mask rcnn训练自己的数据集
原文首发于微信公众号「3D视觉工坊」--mask rcnn训练自己的数据集 前言 最近迷上了mask rcnn,也是由于自己工作需要吧,特意研究了其源代码,并基于自己的数据进行训练~ 本博客参考:ht ...
- 模型压缩 | 无需精雕细琢,随机剪枝足矣!(ICLR 2022)
关注公众号,发现CV技术之美 本文是由京东探索研究院联合荷兰埃因霍温理工大学和德州大学奥斯汀分校完成,探索了随机化剪枝在稀疏训练中不可思议的表现.我们注意到,每当有新的剪枝方法提出来的时候,随机剪枝就 ...
- ICCV 2019 | 港大提出视频显著物体检测算法MGA,大幅提升分割精度
点击我爱计算机视觉标星,更快获取CVML新技术 本文解读了香港大学联合中山大学和深睿医疗人工智能实验室 ICCV2019 论文<Motion Guided Attention for Video ...
- nnUNet原创团队全新力作!MedNeXt:医学图像分割新SOTA
Title:MedNeXt: Transformer-driven Scaling of ConvNets for Medical Image Segmentation MedNeXt:用于医学图像分 ...
- PyTorch学习记录——PyTorch生态
Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...
- 【STARK论文翻译阅读】
STARK:Learning Spatio-Temporal Transformer for Visual Tracking论文翻译阅读 简介 1.引言 2.相关工作 2.1 transformer在 ...
- 深度学习(9):FastFCN论文翻译与学习
FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic FastFCN:重新思考膨胀卷积在语义分割主干网络中的作用 注 ...
- PyTorch:生态简介
PyTorch生态简介 PyTorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了Py ...
- 【论文阅读】Squeeze-and-Attention Networks for Semantic Segmentation(CVPR2020)
论文题目:Squeeze-and-Attention Networks for Semantic Segmentation(用于语义分割的挤压-注意网络) 下载链接:https://arxiv.org ...
最新文章
- 创建Maven版Java工程
- python调用qq互联_Django项目中实现使用qq第三方登录功能
- Py之pandas:利用pandas工具输出每行的索引值、及其对应的行数据
- java 批量处理 示例_Java异常处理教程(包含示例和最佳实践)
- 使用ZeroTier搭建大局域网利用VNC远程桌面
- [转]Spring中的ContextLoaderListener使用
- C语言的标准内存分配函数
- dcp9030cdn定影_兄弟DCP-9030CDN打印驱动下载|兄弟Brother DCP-9030CDN一体打印机驱动官方下载 - 维维软件园...
- fftshift使用
- 域名ip查询步骤与域名如何解析到ip
- python爬虫区划代码表
- php过滤微信表情符号
- php session fixation,Session Fixation 原理与防御
- 使用opencv识别同心圆
- MVP模式网络请求购物车
- 总体标准差-样本标准差
- 10个超棒的界面设计工具
- 基于单片机交通灯控制的c语言程序设计,基于单片机控制的交通灯毕业设计
- 三光(可见光、红外光、激光)云台产品调研
- 什么是Ceph?听听Ceph创始人怎么说
热门文章
- MYSQL误删数据恢复
- python图书库存管理系统_基于Odoo的物流库存管理系统的设计(Python)
- C语言 用矩形法计算定积分∫(0—1)sinxdx、∫(-1—1)cosxdx、∫(0—2)e^xdx
- SWMM源码编译LNK2001 无法解析的外部符号 _swmm_close@0
- 微信内置浏览器cookie设置问题
- html左侧树形图,Qunee for HTML5 - 中文 : 树形布局
- 如何修改ssh端口号
- electron编写我们第一个hello world程序和文件引入
- 测试 minpy gpu加速 numpy 矩阵相乘 matmul matrix multiplication
- FTP文件传输神器:8uftp