代码请见:https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py

"""
Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)@author: tstandley
Adapted by cadeneCreates an Xception Model as defined in:Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdfThis weights ported from the Keras implementation. Achieves the following performance on the validation set:Loss:0.9173 Prec@1:78.892 Prec@5:94.292REMEMBER to set your image size to 3x299x299 for both test and validationnormalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
from __future__ import print_function, division, absolute_import
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init__all__ = ['xception']pretrained_settings = {'xception': {'imagenet': {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth','input_space': 'RGB','input_size': [3, 299, 299],'input_range': [0, 1],'mean': [0.5, 0.5, 0.5],'std': [0.5, 0.5, 0.5],'num_classes': 1000,'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299}}
}class SeparableConv2d(nn.Module):def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):super(SeparableConv2d,self).__init__()self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)def forward(self,x):x = self.conv1(x)x = self.pointwise(x)return xclass Block(nn.Module):def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):super(Block, self).__init__()if out_filters != in_filters or strides!=1:self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)self.skipbn = nn.BatchNorm2d(out_filters)else:self.skip=Nonerep=[]filters=in_filtersif grow_first:rep.append(nn.ReLU(inplace=True))rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))rep.append(nn.BatchNorm2d(out_filters))filters = out_filtersfor i in range(reps-1):rep.append(nn.ReLU(inplace=True))rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))rep.append(nn.BatchNorm2d(filters))if not grow_first:rep.append(nn.ReLU(inplace=True))rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))rep.append(nn.BatchNorm2d(out_filters))if not start_with_relu:rep = rep[1:]else:rep[0] = nn.ReLU(inplace=False)if strides != 1:rep.append(nn.MaxPool2d(3,strides,1))self.rep = nn.Sequential(*rep)def forward(self,inp):x = self.rep(inp)if self.skip is not None:skip = self.skip(inp)skip = self.skipbn(skip)else:skip = inpx+=skipreturn xclass Xception(nn.Module):"""Xception optimized for the ImageNet dataset, as specified inhttps://arxiv.org/pdf/1610.02357.pdf"""def __init__(self, num_classes=1000):""" ConstructorArgs:num_classes: number of classes"""super(Xception, self).__init__()self.num_classes = num_classesself.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu1 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(32,64,3,bias=False)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU(inplace=True)#do relu hereself.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)self.conv3 = SeparableConv2d(1024,1536,3,1,1)self.bn3 = nn.BatchNorm2d(1536)self.relu3 = nn.ReLU(inplace=True)#do relu hereself.conv4 = SeparableConv2d(1536,2048,3,1,1)self.bn4 = nn.BatchNorm2d(2048)self.fc = nn.Linear(2048, num_classes)# #------- init weights --------# for m in self.modules():#     if isinstance(m, nn.Conv2d):#         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels#         m.weight.data.normal_(0, math.sqrt(2. / n))#     elif isinstance(m, nn.BatchNorm2d):#         m.weight.data.fill_(1)#         m.bias.data.zero_()# #-----------------------------def features(self, input):x = self.conv1(input)x = self.bn1(x)x = self.relu1(x)x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.block6(x)x = self.block7(x)x = self.block8(x)x = self.block9(x)x = self.block10(x)x = self.block11(x)x = self.block12(x)x = self.conv3(x)x = self.bn3(x)x = self.relu3(x)x = self.conv4(x)x = self.bn4(x)return xdef logits(self, features):x = nn.ReLU(inplace=True)(features)x = F.adaptive_avg_pool2d(x, (1, 1))x = x.view(x.size(0), -1)x = self.last_linear(x)return xdef forward(self, input):x = self.features(input)x = self.logits(x)return xdef xception(num_classes=1000, pretrained='imagenet'):model = Xception(num_classes=num_classes)if pretrained:settings = pretrained_settings['xception'][pretrained]assert num_classes == settings['num_classes'], \"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)model = Xception(num_classes=num_classes)model.load_state_dict(model_zoo.load_url(settings['url']))model.input_space = settings['input_space']model.input_size = settings['input_size']model.input_range = settings['input_range']model.mean = settings['mean']model.std = settings['std']# TODO: uglymodel.last_linear = model.fcdel model.fcreturn model

Xception: Deep Learning with Depthwise Separable Convolutions相关推荐

  1. 《Deep Learning With Python second edition》英文版读书笔记:第十一章DL for text: NLP、Transformer、Seq2Seq

    文章目录 第十一章:Deep learning for text 11.1 Natural language processing: The bird's eye view 11.2 Preparin ...

  2. 深度可分离卷积Depthwise Separable Convolution

    从卷积神经网络登上历史舞台开始,经过不断的改进和优化,卷积早已不是当年的卷积,诞生了分组卷积(Group convolution).空洞卷积(Dilated convolution 或 À trous ...

  3. 《Deep Learning for Computer Vision withPython》阅读笔记-StarterBundle(第18 - 23章)

    18.检查点模型 截止到P265页 //2022.1.18日22:14开始学习 在第13章中,我们讨论了如何在培训完成后将模型保存和序列化到磁盘上.在上一章中,我们学习了如何在发生欠拟合和过拟合时发现 ...

  4. 【文献学习】Complex-Valued Convolutions for Modulation Recognition using Deep Learning

    目录 1 简介和创新点 1.1 DL中复数的处理综述 1.2 DL对于调制分类的综述 2 系统模型 2.1 二维实数卷积 2.2 整合到现有的DL架构中 3 模型参数 4 实验分析 5 思考和收获哦 ...

  5. A Novel Two-stage Separable Deep Learning Framework for Practical Blind Watermarking论文阅读

    A Novel Two-stage Separable Deep Learning Framework for Practical Blind Watermarking Abstract 数字水印是一 ...

  6. 【IJCV 2022】RIConv++: Effective Rotation Invariant Convolutions for 3D Point Clouds Deep Learning

    文章目录 研究旋转不变就从这里开始吧. [3DV 2019]Rotation Invariant Convolutions for 3D Point Clouds Deep Learning [IJC ...

  7. Image Segmentation Using Deep Learning: A Survery

    图像分割综述–论文翻译    论文地址:https://arxiv.org/pdf/2001.05566.pdf 图像分割综述论文 图像分割综述--论文翻译 摘要 介绍 深度神经网络的介绍 Convo ...

  8. 详述Deep Learning中的各种卷积(二)

    作者:Redflashing 本文梳理举例总结深度学习中所遇到的各种卷积,帮助大家更为深刻理解和构建卷积神经网络. 本文将详细介绍以下卷积概念: 2D卷积(2D Convolution) 3D卷积(3 ...

  9. 全文翻译(全文合集):TVM: An Automated End-to-End Optimizing Compiler for Deep Learning

    全文翻译(全文合集):TVM: An Automated End-to-End Optimizing Compiler for Deep Learning 摘要 人们越来越需要将机器学习应用到各种各样 ...

最新文章

  1. 成功解决ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or
  2. 幅度和幅值有区别吗_童溢金:白银期货和现货白银的区别在哪,你知道吗?
  3. 【转】TechEd第一课:新一代关系管理系统XRM**
  4. Mac下安装Flink的local模式(flink-1.2.0)
  5. python批量提取哔哩哔哩bilibili视频
  6. 推荐一个自动写paper的软件,让IEEE吐血泪奔
  7. svm (opencv)几个主要函数
  8. 关于java小游戏的暂停,退出和从新开始
  9. 云服务器外网访问MySql全程实录
  10. linux 单网卡绑定两个ip
  11. 千元机PK苹果iphone
  12. Ubuntu删除多余内核
  13. 【启明云端】启明云端带你揭开WT32-S3-WROVER神秘面纱
  14. Java设计模式-代理模式笔记
  15. CobaltStrike上线Linux主机(CrossC2)
  16. root精灵有mac版的吗,苹果有root
  17. 立体图像和平面图像质量评价常用数据库
  18. webpack中利用eslint对js进行代码格式检校
  19. 常见的主流浏览器内核
  20. php判断几维数组的方法,PHP判断数组是一维二维或几维实例

热门文章

  1. python 窗口键 键位码_滚轮键按一下 这些功能超方便
  2. Python+OpenCV:Optical Flow(光流)
  3. SQL Server 2014 导入Excel
  4. QMutexLocker作用范围
  5. OpenGL秒安装及显示
  6. 【mongodb系统学习之十】mongodb查询(二)
  7. [转]查看linux服务器硬盘IO读写负载
  8. C++之String的find方法,查找一个字符串在另一个字符串的什么位置;leveldb字符串转数字解析办法...
  9. Cisco网络管理的35个常见问题及解答
  10. 用Asp.Net创建基于Ajax的聊天室程序