1. ResNet理论

残差学习基本单元:

在ImageNet上的结果:

效果会随着模型层数的提升而下降,当更深的网络能够开始收敛时,就会出现降级问题:随着网络深度的增加,准确度变得饱和(这可能不足为奇),然后迅速降级。

ResNet模型:

2. pytorch实现

2.1 基础卷积

conv3$\times\(3 和conv1\)\times$1 基础模块

def conv3x3(in_channel, out_channel, stride=1, groups=1, dilation=1):

return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_channel, out_channel, stride=1):

return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)

参数解释:

in_channel: 输入的通道数目

out_channel:输出的通道数目

stride, padding: 步长和补0

dilation: 空洞卷积中的参数

groups: 从输入通道到输出通道的阻塞连接数

feature size 计算:

output = (intput - filter_size + 2 x padding) / stride + 1

空洞卷积实际卷积核大小:

K = K + (K-1)x(R-1)

K 是原始卷积核大小

R 是空洞卷积参数的空洞率(普通卷积为1)

2.2 模块

- resnet34

- _resnet

- ResNet

- _make_layer

- block

- Bottleneck

- BasicBlock

Bottlenect

class Bottleneck(nn.Module):

expansion = 4

__constants__ = ['downsample']

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,

base_width=64, dilation=1, norm_layer=None):

super(Bottleneck, self).__init__()

if norm_layer is None:

norm_layer = nn.BatchNorm2d

width = int(planes * (base_width / 64.)) * groups

# Both self.conv2 and self.downsample layers downsample the input when stride != 1

self.conv1 = conv1x1(inplanes, width)

self.bn1 = norm_layer(width)

self.conv2 = conv3x3(width, width, stride, groups, dilation)

self.bn2 = norm_layer(width)

self.conv3 = conv1x1(width, planes * self.expansion)

self.bn3 = norm_layer(planes * self.expansion)

self.relu = nn.ReLU(inplace=True)

self.downsample = downsample

self.stride = stride

def forward(self, x):

identity = x

out = self.conv1(x)

out = self.bn1(out)

out = self.relu(out)

out = self.conv2(out)

out = self.bn2(out)

out = self.relu(out)

out = self.conv3(out)

out = self.bn3(out)

if self.downsample is not None:

identity = self.downsample(x)

out += identity

out = self.relu(out)

return out

BasicBlock

class BasicBlock(nn.Module):

expansion = 1

__constants__ = ['downsample']

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,

base_width=64, dilation=1, norm_layer=None):

super(BasicBlock, self).__init__()

if norm_layer is None:

norm_layer = nn.BatchNorm2d

if groups != 1 or base_width != 64:

raise ValueError('BasicBlock only supports groups=1 and base_width=64')

if dilation > 1:

raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

# Both self.conv1 and self.downsample layers downsample the input when stride != 1

self.conv1 = conv3x3(inplanes, planes, stride)

self.bn1 = norm_layer(planes)

self.relu = nn.ReLU(inplace=True)

self.conv2 = conv3x3(planes, planes)

self.bn2 = norm_layer(planes)

self.downsample = downsample

self.stride = stride

def forward(self, x):

identity = x

out = self.conv1(x)

out = self.bn1(out)

out = self.relu(out)

out = self.conv2(out)

out = self.bn2(out)

if self.downsample is not None:

identity = self.downsample(x)

out += identity

out = self.relu(out)

return out

2.3 使用ResNet模块进行迁移学习

import torchvision.models as models

import torch.nn as nn

class RES18(nn.Module):

def __init__(self):

super(RES18, self).__init__()

self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN

self.base = torchvision.models.resnet18(pretrained=False)

self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

def forward(self, x):

out = self.base(x)

return out

class RES34(nn.Module):

def __init__(self):

super(RES34, self).__init__()

self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN

self.base = torchvision.models.resnet34(pretrained=False)

self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

def forward(self, x):

out = self.base(x)

return out

class RES50(nn.Module):

def __init__(self):

super(RES50, self).__init__()

self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN

self.base = torchvision.models.resnet50(pretrained=False)

self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

def forward(self, x):

out = self.base(x)

return out

class RES101(nn.Module):

def __init__(self):

super(RES101, self).__init__()

self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN

self.base = torchvision.models.resnet101(pretrained=False)

self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

def forward(self, x):

out = self.base(x)

return out

class RES152(nn.Module):

def __init__(self):

super(RES152, self).__init__()

self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN

self.base = torchvision.models.resnet152(pretrained=False)

self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)

def forward(self, x):

out = self.base(x)

return out

使用模块直接生成一个类即可,比如训练的时候:

cnn = RES101()

cnn.train() # 改为训练模式

prediction = cnn(img) #进行预测

目前先写这么多,看过了源码以后感觉写的很好,不仅仅有论文中最基础的部分,还有一些额外的功能,模块的组织也很整齐。

平时使用一般都进行迁移学习,使用的话可以把上述几个类中pretrained=False参数改为True.

基于python实现resnet_【深度学习】基于Pytorch的ResNet实现相关推荐

  1. 基于Python机器学习、深度学习技术提升气象、海洋、水文领域实践应用

     Python是功能强大.免费.开源,实现面向对象的编程语言,能够在不同操作系统和平台使用,简洁的语法和解释性语言使其成为理想的脚本语言.除了标准库,还有丰富的第三方库,Python在数据处理.科学计 ...

  2. 基于Python机器学习、深度学习技术提升气象、海洋、水文领域实践应用能力

    目录 专题一.Python软件的安装及入门 专题二.气象常用科学计算库 专题三.气象海洋常用可视化库 专题四.爬虫和气象海洋数据 专题五.气象海洋常用插值方法 专题六.机器学习基础理论和实操 专题七. ...

  3. 基于python的tensorflow_Python深度学习:基于TensorFlow

    前言 第一部分 Python及应用数学基础 第1章 NumPy常用操作 2 1.1 生成ndarray的几种方式 3 1.2 存取元素 5 1.3 矩阵操作 6 1.4 数据合并与展平 7 1.5 通 ...

  4. 深度学习必备书籍——《Python深度学习 基于Pytorch》

    作为一名机器学习|深度学习的博主,想和大家分享几本深度学习的书籍,让大家更快的入手深度学习,成为AI达人!今天给大家介绍的是:<Python深度学习 基于Pytorch> 文章目录 一.背 ...

  5. ElasticDL:首个基于 TensorFlow 实现弹性深度学习的开源系统

    9 月 11 日,蚂蚁金服开源了 ElasticDL 项目,据悉这是业界首个基于 TensorFlow 实现弹性深度学习的开源系统. Google Brain 成员 Martin Wicke 此前在公 ...

  6. 基于Ubuntu18.04下深度学习服务器搭建

    基于Ubuntu18.04下深度学习服务器搭建 目录: 基于Ubuntu18.04下深度学习服务器搭建 主要模块组成 Anaconda安装 CUDA安装 pytorch安装 CuDNN安装 其他常用指 ...

  7. 基于NVIDIA GPUs的深度学习训练新优化

    基于NVIDIA GPUs的深度学习训练新优化 New Optimizations To Accelerate Deep Learning Training on NVIDIA GPUs 不同行业采用 ...

  8. 值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(下)

    作者 | 黄浴 来源 | 转载自知乎专栏自动驾驶的挑战和发展 [导读]在近日发布的<值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(上)>一文中,作者介绍了一部分各大公司和机构基于 ...

  9. 第三十五课.基于贝叶斯的深度学习

    目录 贝叶斯公式 基础问题 贝叶斯深度学习与深度学习的区别 贝叶斯神经网络与贝叶斯网络 贝叶斯神经网络的推理与学习 前向计算 学习 贝叶斯公式 首先回顾贝叶斯公式:p(z∣x)=p(x,z)p(x)= ...

  10. R基于H2O包构建深度学习模型实战

    R基于H2O包构建深度学习模型实战 目录 R基于H2O包构建深度学习模型实战 #案例分析

最新文章

  1. python代码案例详解-Python综合应用名片管理系统案例详解
  2. 机器学习中的不平衡分类方法(part2)--模型评估与选择
  3. 虚拟专题:知识图谱 | 流程工业控制系统的知识图谱构建
  4. php选择nginx还是apache,浅谈apache和nginx的rewrite的区别
  5. redis 获取不到_redis系列之——缓存穿透、缓存击穿、缓存雪崩
  6. python 微信公众号发文章_Python 微信公众号文章爬取
  7. IntelliJ IDEA安装与JDK 环境变量配置
  8. Arduino 连接JDY-08蓝牙模块
  9. 射雕英雄传ol显示服务器断开,射雕英雄传OL6月18日维护更新内容
  10. Ubuntu下codeblocks汉化
  11. 第二十五篇:稳定性之灰度发布
  12. 好队友--超好用的函数插件大全,再也不用为excel函数使用烦恼啦
  13. c语言设计评分程序,C语言程序设计课程设计---设计比赛评分系统
  14. 分布式任务调度框架设计与实现解读(1)
  15. 虚拟偶像成为二次元香饽饽,从直播切入有戏吗?
  16. 阿里巴巴国际站店铺装修悬浮菜单定位,快速导航链接到某个位置,跳转链接悬浮代码工具代码生成器制作锚点链接
  17. 产品 · B端生意的定义和分类
  18. 【智能优化算法-白鲸优化算法】基于白鲸优化算法求解单目标优化问题附matlab代码
  19. 如何快速编写纯CSS菜单?制作CSS精美菜单优化精简代码详细教程
  20. Cookie、Session、Token与JWT解析

热门文章

  1. python在材料方面的应用_python记录材料题带标准答案
  2. VS中监视窗口,即时窗口和输出窗口的使用
  3. 【树莓派】【网摘】树莓派与XBMC及Kodi、LibreELEC插件(三)
  4. ASP.NET MVC5使用Area区域
  5. 解决 Plugin with id 'com.github.dcendents.android-maven' not found.
  6. easyphp环境配置
  7. php 大流量网站访问
  8. SD卡启动盘制作软件
  9. 转:flex [Inspectable]标签详解
  10. Editplus For Python[转]