1、文件说明

Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:

2、项目结构和函数设计

Model.py 的类

class DropPath(nn.Module)def forward(self, x)
class PatchEmbed(nn.Module)def forward(self, x)
class PatchMerging(nn.Module):def forward(self, x, H, W)
class Mlp(nn.Module):def forward(self, x):
class WindowAttention(nn.Module):def forward(self, x, mask: Optional[torch.Tensor] = None):
class SwinTransformerBlock(nn.Module):def forward(self, x, attn_mask):
class BasicLayer(nn.Module):def create_mask(self, x, H, W):def forward(self, x, H, W):
class SwinTransformer(nn.Module):def _init_weights(self, m):def forward(self, x)

Model.py 的函数

def drop_path_f(x, drop_prob: float = 0., training: bool = False)
def window_partition(x, window_size: int)
def window_reverse(windows, window_size: int, H: int, W: int)
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):

My_dataset.py只有类

class MyDataSet(Dataset):
---def __len__(self):
---def __getitem__(self, item):@staticmethod
---def collate_fn(batch):

Predict.py只有函数

def main():
if __name__ == '__main__':main()

Train.py只有函数

def main(args):
if __name__ == '__main__':。。。main(opt)

Utils.py只有函数

def read_split_data(root: str, val_rate: float = 0.2):
def plot_data_loader_image(data_loader):
def write_pickle(list_info: list, file_name: str):
def read_pickle(file_name: str) -> list:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):

Swin-Transformer 论文代码介绍

1 开发环境

 Python 3.6
 torch 1.7.1
 GPU

2 功能设计

实验数据集的说明:
数据来源
http://download.tensorflow.org/example_images/flower_photos.tgz
5类花的图片做分类:
3670 images were found in the dataset.
2939 images for training.
731 images for validation.

Daisy:菊花
Dandelion:蒲公英
Roses:玫瑰
Sunflowers:向日葵
Tulips:郁金香

3 、文件说明

Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:功能类函数
Model.py 的类
DropPath:设置各模块内的dropout率
PatchEmbed:对图片像素进行划分patch
PatchMerging:对图进行petch的拼接和线性映射
Mlp:SwinTransformerBlock后面一段的使用的
WindowAttention:window内部计算attention
SwinTransformerBlock:构建单个SwinTransformerBlock模型,该模型中含有W-MSA和SW-MSA两个模块
SwinTransformer:构建整个分类模型,这个类调用其他类,共同组成整个模型,从Patchpartion到LinearEmbedding(即类PatchEmbed),到四个SwinTransformerBlock,以及在SwinTransformerBlock中使用是否使用PatchMerging,经过四个阶段的SwinTransformerBlock之后输出展平的向量。
Model.py 的函数
window_partition:对特征图进行划分,划分成一个一个没有重叠的window
window_reverse:将window还原成特征图
定义各种模型,用于实例化模型
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
My_dataset.py只有类
MyDataSet(Dataset):构建获取数据集中元素和大小的方法
@staticmethod
collate_fn(batch):用于单独调用使用,将一个批次的图片转为向量并拼在一起
Predict.py只有函数
main(): 创建预测图片类别的函数,展示预测的图片以及被预测图片属于每个类别的概率
if name == ‘main’:
main()
开始预测
Train.py只有函数
main(args)
获取训练集和验证集,对图片进行处理,调整两个数据集中图片的大小,实例化模型,训练模型,保存模型。
自定义参数,解析参数,调用并执行main(args),训练分类模型
Utils.py只有函数
read_split_data:读取图片和图片的类别,划分训练集和验证集
train_one_epoch:
定义损失函数:torch.nn.CrossEntropyLoss()
进行一个epoch的训练,返回损失和精确率
Evaluate

4 流程

运行train.py训练模型,训练了个epoch,最高精确率可到96.6%

5 效果演示

运行predict.py对单独一张图片进行预测类别

Swintransformer详细设计文档相关推荐

  1. 开发详细设计文档_郑州APP开发:开发前,进行详细设计有没有必要?

    郑州燚轩软件科技有限公司● 点击蓝字关注我们 ● 一般进行软件开发 的人都知道,在进行郑州APP开发时,项目流程主要包括需求分析.概要设计.详细设计.编码和测试 ,那么在其中有了概要设计的情况下,为什 ...

  2. rd如何撰写总体设计文档和详细设计文档

    转自:http://www.habadog.com/2012/10/18/rd-how-to-write-document/ rd需要撰写的设计文档主要分为:总体设计文档 + 详细设计文档,后简称为& ...

  3. python手机销售系统详细设计_数据库详细设计文档 .doc

    [原创]定制代写r/python/spss/matlab/WEKA/sas/sql/C++/stata/eviews数据挖掘和统计分析可视化调研报告等服务(附代码数据), 咨询邮箱: 30253934 ...

  4. 【ZT】详细设计文档规范

    1.引言 1.1编写目的 [阐明编写详细设计说明书的目的,指明读者对象.] 1.2项目背景 [应包括项目的来源和主管部门等.] 1.3定义 [列出文档中所用到的专门术语的定义和缩写词的原文.] 1.4 ...

  5. python飞机大战概要设计_飞机大战详细设计文档 第三次修改

    飞机大战详细设计说明书 1. 引言部分 本部分主要说明项目背景和术语定义等. 1.1编写目的 本部分阐明编写详细设计说明书的目的,指明读者对象. 该文档的目的是描述设计飞机大战的每个模块的细节,包括模 ...

  6. oracle客户关系系统,Java swing Oracle实现的客户关系管理系统项目源码附带详细设计文档...

    <p style="font-family:"font-size:16px;text-indent:2em;color:#666666;background-color:#F ...

  7. Java Swing Sqlserver实现的酒店管理系统附带详细设计文档免费下载

    今天给大家分享一款由Java swing sqlserver实现的酒店管理系统,整个系统功能非常完善,结构层次设计的也很合理,数据库采用的是sqlserver,此外附带有系统详细的需求文档,设计文档, ...

  8. Java 食堂管理系统-MySQL数据库,窗体程序 有详细设计文档

    今天为大家分享一个java 编写的食堂管理系统,目前系统功能已经完善,后续会进一步完善.整个系统界面漂亮,有完整得源码,希望大家可以喜欢.喜欢的帮忙点赞和关注.一起编程.一起进步. 开发环境 开发语言 ...

  9. 前端详细设计文档怎么写_UI设计师简历应该怎么写?

    像这种分享,常规开篇都应该说说当前的就业趋势啦,分析分析行业形势啦这类的 但我不想按流程写 行业不论什么时候分析,它都没好过,什么红利期什么风口,那更是从来没赶上过.但凡我能跟点风,我也不能到现在还没 ...

最新文章

  1. SAP MIGO收货界面'批次'分类选项卡里不出现'分类'按钮之对策
  2. 邮件系列(二)-发送邮件
  3. 爱酷pro充电测试软件,iQOO 5 Pro续航、充电测试简报
  4. android 约束布局的坑,android - 使用android约束布局2.0.0 Flow将项目放置一行 - 堆栈内存溢出...
  5. Mysql ==》 单表查询
  6. Object类解析(简)
  7. web自动化知识点-02
  8. Tableau安装与破解
  9. python遗传算法
  10. 计算机微课课件评比活动总结,教学大赛总结.doc
  11. 360Lib整体介绍
  12. php strict,PHP 5.4中的E_STRICT和E_ALL有什么区别?
  13. 历史二—— 浮点运算与数组下标寻址
  14. gitee注册新用户收不到验证码, 不管是手机还是邮箱都收不到验证码解决方案
  15. 引用变量和对象--作为初学者的混淆
  16. 大端和小端的区别和判断
  17. Atom快速跳转到函数定义处
  18. RPC服务器不可用解决方法汇集
  19. 傻白入门芯片设计,盘点CPU业界的顶尖人才(十四)
  20. 量化投资之王:他连续27年回报率打败巴菲特

热门文章

  1. 使用java 自带的webservice
  2. 设置npm的registry
  3. C#实现异步消息队列
  4. Javascript 学习笔记 2: 标识语句
  5. flash 多个文件上传
  6. 爱吃苹果的与喜欢篮球的没必要非得达成一致~
  7. 冒泡算法的三种JavaScript表示
  8. c语言 自动测试,C语言测试。自己实现scandir 函数
  9. php 整数转换为32 位,PHP哈希函数返回一个整数(32位int)(PHP hashing function that returns an integer (32bit int))...
  10. centos dovecot mysql_Centos6.4 配置postfix+dovecot+mysql