Original code:

HuCaoFighting/Swin-Unet on GitHub, the code for "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation": https://github.com/HuCaoFighting/Swin-Unet

The code is structured the same way as the TransUNet project I reproduced earlier: GitHub Reproduction: TransUNet (Transformer for Semantic Segmentation), https://blog.csdn.net/qq_20373723/article/details/115548900?spm=1001.2014.3001.5501

Data preparation is exactly the same as in the reproduction below, so I won't go over it again:

GitHub Reproduction: TransUnet Update, https://blog.csdn.net/qq_20373723/article/details/117225238?spm=1001.2014.3001.5501 (keep the folder names as described there; each image and its label must have the same file name).
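Judging from `read_own_data` in the modified `dataset_synapse.py` further down, the loader expects `root_path/<mode>/images` and `root_path/<mode>/labels`, and it walks the labels folder, so every label needs an image with the same name. A small pre-flight check written under that assumption (the helper `check_pairs` is hypothetical, not part of either repo) can catch mismatches before training starts:

```python
import os
import tempfile


def check_pairs(root_path, mode="train"):
    """Hypothetical helper: report label files that have no same-named image.

    Assumes the layout read_own_data() uses:
    root_path/<mode>/images/<name> and root_path/<mode>/labels/<name>.
    Returns (number_of_labels, sorted list of orphaned label names).
    """
    image_root = os.path.join(root_path, mode, "images")
    label_root = os.path.join(root_path, mode, "labels")
    images = set(os.listdir(image_root))
    labels = sorted(os.listdir(label_root))
    # read_own_data iterates over the label folder, so each label must have an image
    missing = [name for name in labels if name not in images]
    return len(labels), missing
```

For example, `check_pairs('./data/build512/')` (the default `--root_path` in train.py) returns the label count and any label names without a matching image.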

Note that a few extra packages need to be installed. Here is my environment for reference (trimmed, for reference only):

Package                            Version               Location
---------------------------------- --------------------- ---------------------
cupy                               6.5.0+cuda101
cupy-cuda110                       9.6.0
cycler                             0.10.0
cymem                              2.0.6
Cython                             0.29.21
cytoolz                            0.9.0.1
easycython                         1.0.7
easydict                           1.9
efficientnet-pytorch               0.6.3
h5py                               2.10.0
ImageHash                          4.2.1
imageio                            2.5.0
imagesize                          1.1.0
json5                              0.9.6
keras                              2.8.0
Keras-Applications                 1.0.8
keras-bert                         0.86.0
keras-contrib                      2.0.8
keras-embed-sim                    0.8.0
keras-layer-normalization          0.14.0
keras-multi-head                   0.27.0
keras-nightly                      2.9.0.dev2022031807
keras-pos-embd                     0.11.0
keras-position-wise-feed-forward   0.6.0
Keras-Preprocessing                1.1.2
keras-self-attention               0.46.0
keras-transformer                  0.38.0
labelme                            3.16.5
labelme2coco                       0.1.2
langdetect                         1.0.9
lazy-object-proxy                  1.3.1
libarchive-c                       2.8
Markdown                           3.3.3
MarkupSafe                         2.0.1
matplotlib                         3.2.2
matplotlib-inline                  0.1.3
mayavi                             4.7.3
mccabe                             0.6.1
MedPy                              0.4.0
menuinst                           1.4.16
metview                            1.8.1
mistune                            0.8.4
mkl-fft                            1.0.10
mkl-random                         1.0.2
ml-collections                     0.1.0
mlbox                              0.8.5
mmcv                               1.3.12
mmdet                              2.16.0
mock                               2.0.0
more-itertools                     6.0.0
mpmath                             1.1.0
msgpack                            0.6.1
mtcnn                              0.1.0
multidict                          5.2.0
multipledispatch                   0.6.0
munch                              2.5.0
munkres                            1.1.4
murmurhash                         1.0.6
navigator-updater                  0.2.1
nbclassic                          0.3.1
nbconvert                          5.4.1
nbformat                           4.4.0
nest-asyncio                       1.5.1
networkx                           2.2
nibabel                            3.2.1
nltk                               3.4
nnunet                             1.6.6                 d:\csdn\nnunet-master
nose                               1.3.7
notebook                           5.7.8
numba                              0.55.1
numexpr                            2.6.9
numpy                              1.19.5
oauthlib                           3.1.0
odo                                0.5.1
olefile                            0.46
omegaconf                          2.0.0
open3d                             0.13.0
opencv-contrib-python              3.4.2.17
opencv-python                      4.5.2.52
opencv-python-headless             4.5.2.52
openpyxl                           2.6.1
opt-einsum                         3.3.0
ospybook                           1.0
packaging                          21.3
pandas                             0.25.3
pandocfilters                      1.4.2
parso                              0.3.4
partd                              0.3.10
path.py                            11.5.0
pathlib2                           2.3.3
patsy                              0.5.2
pbr                                5.5.1
PCV                                1.0
pep8                               1.7.1
pickleshare                        0.7.5
Pillow                             8.2.0
pinyin                             0.4.0
pip                                19.0.3
pixellib                           0.6.6
pkginfo                            1.5.0.1
plac                               1.1.3
pluggy                             0.9.0
ply                                3.11
pooch                              1.6.0
prefetch-generator                 1.0.1
preshed                            3.0.6
pretrainedmodels                   0.7.4
progressbar                        2.5
prometheus-client                  0.6.0
prompt-toolkit                     2.0.9
protobuf                           3.19.4
protobuf-py3                       2.5.1
psutil                             5.8.0
py                                 1.8.0
py3nvml                            0.2.6
pyaml                              21.10.1
pyarrow                            5.0.0
pyasn1                             0.4.8
pyasn1-modules                     0.2.8
pycocotools                        2.0.2
pycocotools-windows                2.0.0.2
pycodestyle                        2.5.0
pycosat                            0.6.3
pycparser                          2.19
pycrypto                           2.6.1
pycurl                             7.43.0.2
pydeck                             0.7.0
pydensecrf                         1.0rc2
pyDeprecate                        0.3.1
pydicom                            2.1.2
pyface                             7.3.0
pyflakes                           2.1.1
pygeos                             0.10
Pygments                           2.9.0
PyHamcrest                         2.0.2
pykdtree                           1.3.4
pylint                             2.3.1
pyodbc                             4.0.26
pyOpenSSL                          19.0.0
pyparsing                          2.3.1
pyproj                             3.0.0.post1
pyreadline                         2.1
pyresample                         1.21.1
pyrser                             0.2.0
pyrsistent                         0.14.11
PySocks                            1.6.8
pytest                             4.3.1
pytest-arraydiff                   0.3
pytest-astropy                     0.5.0
pytest-doctestplus                 0.3.0
pytest-openfiles                   0.3.2
pytest-remotedata                  0.3.1
python-dateutil                    2.8.0
python-editor                      1.0.4
pytorch-lightning                  1.0.8
pytorch-toolbelt                   0.3.0
pytz                               2020.1
PyWavelets                         1.1.1
pywin32                            225
pywinpty                           1.1.3
PyYAML                             5.3.1
pyzmq                              18.0.0
QtAwesome                          0.5.7
qtconsole                          4.4.3
QtPy                               1.7.0
rasterio                           1.2.0
rasterstats                        0.15.0
realesrgan                         0.2.4.0
regex                              2021.4.4
requests                           2.21.0
requests-oauthlib                  1.3.0
requests-unixsocket                0.2.0
resampy                            0.2.2
retry                              0.9.2
rope                               0.12.0
rsa                                4.6
Rtree                              0.9.7
ruamel-yaml                        0.15.46
sacremoses                         0.0.45
scikit-image                       0.18.1
scikit-learn                       0.22.1
scipy                              1.7.3
seaborn                            0.11.0
segmentation-models-pytorch        0.1.3
Send2Trash                         1.5.0
sentencepiece                      0.1.95
sentinelsat                        0.14
seqeval                            0.0.19
service-identity                   18.1.0
setuptools                         50.3.2
Shapely                            1.7.1
simplegeneric                      0.8.1
SimpleITK                          2.0.2
simplejson                         3.17.2
singledispatch                     3.4.0.3
six                                1.15.0
sklearn                            0.0
slidingwindow                      0.0.14
smart-open                         5.1.0
smmap                              4.0.0
sniffio                            1.2.0
snowballstemmer                    1.2.1
snuggs                             1.4.7
sortedcollections                  1.1.2
sortedcontainers                   2.1.0
SoundFile                          0.10.3.post1
soupsieve                          1.8
spacy                              2.3.7
Sphinx                             1.8.5
sphinxcontrib-websupport           1.1.0
spyder                             3.3.3
spyder-kernels                     0.4.2
SQLAlchemy                         1.4.13
srsly                              1.0.5
statsmodels                        0.13.1
streamlit                          0.89.0
sympy                              1.3
syntok                             1.3.1
tables                             3.5.2
tensorboard                        2.4.0
tensorboard-data-server            0.6.0
tensorboard-plugin-wit             1.8.1
tensorboardX                       2.5
test-tube                          0.7.5
testpath                           0.4.2
thinc                              7.4.5
thop                               0.0.31.post2005241907
threadpoolctl                      2.1.0
tifffile                           2021.4.8
tiffile                            2018.10.18
timm                               0.4.12
tokenizers                         0.10.3
toml                               0.10.2
tomlkit                            0.7.2
toolz                              0.9.0
torch                              1.7.0+cu110
torch2trt                          0.3.0
torchaudio                         0.7.0
torchfile                          0.1.0
torchgeometry                      0.1.2
torchmetrics                       0.5.1
torchnet                           0.0.4
torchsummary                       1.5.1
torchvision                        0.8.1+cu110
tornado                            6.1
tqdm                               4.48.2
traceback2                         1.4.0
traitlets                          4.3.2
traits                             6.2.0
traitsui                           7.2.1
transformers                       4.3.3
ttach                              0.0.3
Twisted                            19.2.0

Getting started

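Below I go through everything I changed in some detail. A few places carry comments of my own, so compare carefully against the original source to see what changed. I tested training on building extraction and, personally, I don't think the results are very good: judging from the predictions, the details look poor when zoomed in. I'm not sure whether this network simply doesn't suit this kind of remote sensing data. (Update: after a preliminary investigation, the problem lies in the loss function. I recommend the Dice loss below combined with nn.BCELoss, which should work better: segmentation_models.pytorch/dice.py in qubvel/segmentation_models.pytorch, https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/dice.py)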
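The update's advice (a Dice term plus BCE) can be sketched in plain PyTorch as follows. This is a minimal illustration under my own assumptions, not the segmentation_models.pytorch implementation: the model emits one raw logit per pixel, Dice is computed on sigmoid probabilities, and the BCE term uses `nn.BCEWithLogitsLoss`, the logits counterpart of the `nn.BCELoss` mentioned above (it applies the sigmoid internally and is numerically safer). The class name `DiceBCELoss` and the 0.5/0.5 weights are illustrative.

```python
import torch
import torch.nn as nn


class DiceBCELoss(nn.Module):
    """Illustrative binary loss: weighted soft Dice (on probabilities) + BCE (on logits)."""

    def __init__(self, dice_weight=0.5, bce_weight=0.5, smooth=1.0):
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, logits, target):
        # logits and target share a shape such as [N, 1, H, W]; target holds 0/1 floats
        probs = torch.sigmoid(logits)
        probs_flat = probs.reshape(probs.shape[0], -1)
        target_flat = target.reshape(target.shape[0], -1)
        intersection = (probs_flat * target_flat).sum(dim=1)
        denom = probs_flat.sum(dim=1) + target_flat.sum(dim=1)
        # Soft Dice per sample: (2|X ∩ Y| + s) / (|X| + |Y| + s)
        dice = (2 * intersection + self.smooth) / (denom + self.smooth)
        dice_loss = 1 - dice.mean()
        return self.dice_weight * dice_loss + self.bce_weight * self.bce(logits, target)
```

Usage would be `loss = DiceBCELoss()(outputs, label_batch)` with the raw network outputs, never already-sigmoided values, since the sigmoid is applied inside.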

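Note that the original code does multi-class segmentation and I changed it to binary. Below is all the code I modified (I didn't actually delete anything from the original source: the original lines are commented out and my own code is added after them).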
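To make the multi-class to binary switch concrete: the original Swin-Unet head outputs `num_classes` channels (e.g. 9 for Synapse), is trained with `CrossEntropyLoss` against an integer label map and is decoded with `argmax`; with `--num_classes 1` there is a single logit channel, the label is a float 0/1 mask of the same shape as the output, and the prediction is a sigmoid threshold. The tensor sizes below (batch 2, 4×4 images) are purely illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, H, W = 2, 4, 4

# Multi-class (original repo): C channels, integer class index per pixel
C = 9
multi_logits = torch.randn(N, C, H, W)
multi_target = torch.randint(0, C, (N, H, W))  # dtype torch.long, no channel dim
multi_loss = nn.CrossEntropyLoss()(multi_logits, multi_target)
multi_pred = torch.argmax(torch.softmax(multi_logits, dim=1), dim=1)  # [N, H, W]

# Binary (this post): 1 channel, float 0/1 mask with the same shape as the logits
bin_logits = torch.randn(N, 1, H, W)
bin_target = (torch.rand(N, 1, H, W) > 0.5).float()
bin_loss = nn.BCEWithLogitsLoss()(bin_logits, bin_target)
bin_pred = (torch.sigmoid(bin_logits) >= 0.5).float()  # [N, 1, H, W]
```

This is also why the trainer below squeezes away the channel dimension and replaces the softmax/argmax visualisation with a sigmoid threshold.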

1. Modified code

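train.py, minor changes, mostly to the arguments (some were removed). Pay attention to the image size: it should preferably be a multiple of 2, and it must be divisible by WINDOW_SIZE in swin_tiny_patch4_window7_224_lite.yaml.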
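Why the size constraint matters: Swin blocks split each feature map into non-overlapping `WINDOW_SIZE × WINDOW_SIZE` windows, and with `patch4` the four encoder stages run at `img_size/4`, `/8`, `/16` and `/32`, with patch merging halving the size between stages. So, as a rule of thumb, each of those side lengths should be divisible by the window size (in the Swin reference code a stage that is not larger than the window just uses one window covering it). The helper below is a hypothetical check written from that reasoning, not code from the repo: 224 works with the default window of 7 (56, 28, 14, 7), whereas 512 (128, 64, 32, 16) needs a different window, such as 8.

```python
def check_swin_sizes(img_size, window_size=7, patch_size=4, num_stages=4):
    """Hypothetical sanity check for a Swin-style encoder.

    Returns the per-stage feature-map side lengths, raising ValueError if a
    stage cannot be split into whole windows or cannot be halved by patch merging.
    """
    if img_size % patch_size != 0:
        raise ValueError(f"img_size {img_size} is not divisible by patch_size {patch_size}")
    sides = [img_size // patch_size // (2 ** i) for i in range(num_stages)]
    for i, side in enumerate(sides):
        # Patch merging halves H and W between stages, so all but the last stage must be even
        if i < num_stages - 1 and side % 2 != 0:
            raise ValueError(f"stage {i}: side {side} is odd, patch merging needs an even size")
        # A stage no larger than the window is covered by a single window
        if side > window_size and side % window_size != 0:
            raise ValueError(f"stage {i}: side {side} is not divisible by window_size {window_size}")
    return sides
```

For instance, `check_swin_sizes(args.img_size, window_size=8)` before building the network fails fast instead of crashing inside the window partition.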

# -*- coding: utf-8 -*-
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vision_transformer import SwinUnet as ViT_seg
from trainer import trainer_synapse
from config import get_config

parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='./data/build512/', help='root dir for data')
# No longer used after my changes; you can ignore the places below that reference it
parser.add_argument('--dataset', type=str, default='Synapse', help='experiment_name')
# No longer used after my changes; you can ignore the places below that reference it
parser.add_argument('--list_dir', type=str, default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int, default=1, help='output channel of network')
parser.add_argument('--output_dir', type=str, default='./weights/', help='output dir')
parser.add_argument('--max_iterations', type=int, default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int, default=200, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int,  default=1, help='whether use deterministic training')
parser.add_argument('--base_lr', type=float,  default=1e-3, help='segmentation network learning rate')
parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input')
parser.add_argument('--seed', type=int, default=1234, help='random seed')
parser.add_argument('--cfg', type=str, default='./configs/swin_tiny_patch4_window7_224_lite.yaml' , required=False, metavar="FILE", help='path to config file', )
parser.add_argument("--opts",help="Modify config options by adding 'KEY VALUE' pairs. ",default=None,nargs='+',)
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],help='no: no cache, ''full: cache all data, ''part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')

args = parser.parse_args()
if args.dataset == "Synapse":
    # args.root_path = os.path.join(args.root_path, "train_npz")
    pass
config = get_config(args)


if __name__ == "__main__":
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'list_dir': './lists/lists_Synapse',
            'num_classes': args.num_classes,
        },
    }

    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.list_dir = dataset_config[dataset_name]['list_dir']

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda()
    net.load_from(config)

    trainer = {'Synapse': trainer_synapse, }
    trainer[dataset_name](args, net, args.output_dir)

dataset_synapse.py, major changes, mainly new functions for loading your own data (commented inline):

# -*- coding: utf-8 -*-
import os
import cv2
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset


def random_rot_flip(image, label):
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label


def random_rotate(image, label):
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label


class RandomGenerator(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        x, y = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)  # why not 3?
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample


class Synapse_dataset(Dataset):
    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform  # using transform in torch!
        self.split = split
        self.sample_list = open(os.path.join(list_dir, self.split + '.txt')).readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        if self.split == "train":
            slice_name = self.sample_list[idx].strip('\n')
            data_path = os.path.join(self.data_dir, slice_name + '.npz')
            data = np.load(data_path)
            image, label = data['image'], data['label']
        else:
            vol_name = self.sample_list[idx].strip('\n')
            filepath = self.data_dir + "/{}.npy.h5".format(vol_name)
            data = h5py.File(filepath)
            image, label = data['image'][:], data['label'][:]

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample


# From here on: my own standard loading code for custom data, which can be reused in other reproductions too!
# *********************** Data augmentation ***********************
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
        hue_shift = np.uint8(hue_shift)
        h += hue_shift
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        # image = cv2.merge((s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image


def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    if np.random.random() < u:
        height, width, channel = image.shape

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # np.cos / np.pi instead of np.math.*, which newer NumPy versions removed
        cc = np.cos(angle / 180 * np.pi) * sx
        ss = np.sin(angle / 180 * np.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
                                    borderValue=(0, 0, 0,))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
                                   borderValue=(0, 0, 0,))

    return image, mask


def randomHorizontalFlip(image, mask, u=0.5):
    if np.random.random() < u:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)
    return image, mask


def randomVerticleFlip(image, mask, u=0.5):
    if np.random.random() < u:
        image = cv2.flip(image, 0)
        mask = cv2.flip(mask, 0)
    return image, mask


def randomRotate90(image, mask, u=0.5):
    if np.random.random() < u:
        image = np.rot90(image)
        mask = np.rot90(mask)
    return image, mask


# ********************** Functions for loading your own data **********************
# Walk the data folder; note that the paths are built by joining root_path with the split name
def read_own_data(root_path, mode='train'):
    images = []
    masks = []
    image_root = os.path.join(root_path, mode + '/images')
    gt_root = os.path.join(root_path, mode + '/labels')
    for image_name in os.listdir(gt_root):
        image_path = os.path.join(image_root, image_name)
        label_path = os.path.join(gt_root, image_name)
        images.append(image_path)
        masks.append(label_path)
    return images, masks


# Training data loading (with augmentation)
def own_data_loader(img_path, mask_path):
    img = cv2.imread(img_path)
    # img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)
    mask = cv2.imread(mask_path, 0)
    # mask = cv2.resize(mask, (512,512), interpolation = cv2.INTER_NEAREST)
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    mask = np.expand_dims(mask, axis=2)
    img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    # img = np.array(img, np.float32) / 255.0
    # mask = np.array(mask, np.float32)
    mask = np.array(mask, np.float32) / 255.0
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0
    img = np.array(img, np.float32).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask


# Validation/test data loading (no augmentation)
def own_data_test_loader(img_path, mask_path):
    img = cv2.imread(img_path)
    # img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)
    mask = cv2.imread(mask_path, 0)
    # mask = cv2.resize(mask, (512,512), interpolation = cv2.INTER_NEAREST)
    mask = np.expand_dims(mask, axis=2)
    img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    # img = np.array(img, np.float32) / 255.0
    # mask = np.array(mask, np.float32)
    mask = np.array(mask, np.float32) / 255.0
    mask[mask >= 0.5] = 1
    mask[mask < 0.5] = 0
    # mask[mask > 0] = 1
    img = np.array(img, np.float32).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask


class ImageFolder(Dataset):
    def __init__(self, root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        if self.mode == 'test':
            img, mask = own_data_test_loader(self.images[index], self.labels[index])
        else:
            img, mask = own_data_loader(self.images[index], self.labels[index])
        img = torch.Tensor(img)
        mask = torch.Tensor(mask)
        return img, mask

    def __len__(self):
        # assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
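For reference, the value transforms in `own_data_loader` and `own_data_test_loader` reduce to a few NumPy steps: pixels go from [0, 255] to [-1.6, 1.6] via `x / 255 * 3.2 - 1.6`, the grayscale label is scaled to [0, 1] and thresholded at 0.5, and both arrays are moved from HWC to CHW. The function name below is hypothetical, and the random arrays only stand in for `cv2.imread` results.

```python
import numpy as np


def to_network_arrays(img_bgr, mask_gray):
    """Hypothetical helper reproducing the value transforms used by the loaders above."""
    img = np.asarray(img_bgr, np.float32) / 255.0 * 3.2 - 1.6  # [0, 255] -> [-1.6, 1.6]
    mask = np.asarray(mask_gray, np.float32)[:, :, None] / 255.0  # HxW -> HxWx1, in [0, 1]
    mask = (mask >= 0.5).astype(np.float32)  # same result as the two in-place thresholds
    return img.transpose(2, 0, 1), mask.transpose(2, 0, 1)  # HWC -> CHW


rng = np.random.default_rng(0)
img, mask = to_network_arrays(
    rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8),
    rng.integers(0, 256, size=(8, 8), dtype=np.uint8),
)
```

Whatever normalization you choose, the prediction code has to apply exactly the same one, otherwise the network sees inputs in a different range from training.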

trainer.py, major changes, mainly switching to my own data loading function and loss function, plus a learning-rate decay strategy:

import argparse
import logging
import os
import random
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import DiceLoss, BinaryDiceLoss
from torchvision import transforms
from utils import test_single_volume
from pytorch_toolbelt import losses as L
from datasets.dataset_synapse import ImageFolder


def trainer_synapse(args, model, snapshot_path):
    # from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    max_iterations = args.max_iterations
    # db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
    #                            transform=transforms.Compose(
    #                                [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    # Replaced with my own dataset
    db_train = ImageFolder(args.root_path, mode='train')
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    # trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
    #                          worker_init_fn=worker_init_fn)
    # Replaced with my own loader settings
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=0,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)
    model.train()
    # ce_loss = CrossEntropyLoss()
    # bce_loss = nn.BCELoss()
    # dice_loss = DiceLoss(num_classes)
    bce_loss = nn.BCEWithLogitsLoss()
    dice_loss = BinaryDiceLoss()
    # BinaryDiceLoss expects probabilities while the model outputs logits, so the Dice term
    # is computed on torch.sigmoid(outputs) in the loop below instead of via JointLoss
    # loss_fn = L.JointLoss(first=dice_loss, second=bce_loss, first_weight=0.5, second_weight=0.5).cuda()
    # optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,  # T_0: number of epochs before the first restart
        T_mult=2,  # T_mult: growth factor after each restart, i.e. T_0 = T_0 * T_mult
        eta_min=1e-6  # minimum learning rate
    )
    writer = SummaryWriter(snapshot_path + '/log')
    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)  # max_epoch = max_iterations // len(trainloader) + 1
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        # for i_batch, sampled_batch in enumerate(trainloader):
        for image_batch, label_batch in trainloader:
            # image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
            outputs = model(image_batch)
            # print(outputs)  # torch.Size([6, 2, 224, 224])
            # print(label_batch.shape)  # torch.Size([6, 1, 224, 224])
            # ce_loss = CrossEntropyLoss() is meant for multi-class; switched to BCE here
            # loss_ce = ce_loss(outputs, label_batch[:].long())
            # loss_dice = dice_loss(outputs, label_batch, softmax=True)
            # loss = 0.4 * loss_ce + 0.6 * loss_dice
            outputs = torch.squeeze(outputs)
            label_batch = torch.squeeze(label_batch)
            # loss_ce = bce_loss(outputs, label_batch)
            # loss_dice = dice_loss(outputs, label_batch)
            # loss = 0.4 * loss_ce + 0.6 * loss_dice
            # loss = loss_fn(outputs, label_batch)
            loss = 0.5 * dice_loss(torch.sigmoid(outputs), label_batch) + 0.5 * bce_loss(outputs, label_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Poly decay, applied after every step: apart from the first step of each epoch,
            # it overwrites the lr that scheduler.step() set at the end of the previous epoch
            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            # writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            # logging.info('iteration %d : loss : %f, loss_ce: %f' % (iter_num, loss.item(), loss_ce.item()))

            if iter_num % 20 == 0:
                image = image_batch[1, 0:1, :, :]
                image = (image - image.min()) / (image.max() - image.min())
                writer.add_image('train/Image', image, iter_num)
                # outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True)
                # writer.add_image('train/Prediction', outputs[1, ...] * 50, iter_num)
                outputs = torch.sigmoid(outputs.detach())
                outputs[outputs >= 0.5] = 1
                outputs[outputs < 0.5] = 0
                temp = torch.unsqueeze(outputs[1], 0)  # sample 1, the same sample as the image and label logged here
                writer.add_image('train/Prediction', temp * 50, iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))

        save_interval = 10  # int(max_epoch/6)
        if epoch_num > int(max_epoch / 2) and (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if epoch_num >= max_epoch - 1:
            save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))
            iterator.close()
            break
        scheduler.step()

    writer.close()
    return "Training Finished!"
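A note on the learning rate in this trainer: `CosineAnnealingWarmRestarts` is stepped once per epoch, but the loop also writes the poly value `base_lr * (1 - iter_num / max_iterations) ** 0.9` into every param group after each optimizer step, so apart from the first step of each epoch the poly curve is what the optimizer actually uses. To see that curve for your own settings, here is a pure-Python sketch; `poly_schedule` is a hypothetical helper, while `poly_lr` is the same formula as in the trainer.

```python
def poly_lr(base_lr, iter_num, max_iterations, power=0.9):
    """Poly decay as written in trainer_synapse: base_lr * (1 - iter/max_iter) ** power."""
    return base_lr * (1.0 - iter_num / max_iterations) ** power


def poly_schedule(base_lr, iters_per_epoch, max_epochs, power=0.9):
    """Hypothetical helper: the poly lr at the start of each epoch."""
    max_iterations = iters_per_epoch * max_epochs  # same definition as in the trainer
    return [poly_lr(base_lr, e * iters_per_epoch, max_iterations, power) for e in range(max_epochs)]
```

With the defaults in train.py that would be `poly_schedule(1e-3, len(trainloader), 200)`. If you would rather have the cosine warm restarts take effect, one option is to drop the manual `param_group['lr'] = lr_` loop; which schedule works better on your data is something to test.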

utils.py, minor changes, mainly adding a Dice loss function for binary segmentation:

import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom
import torch.nn as nn
import SimpleITK as sitk


class BinaryDiceLoss(nn.Module):
    r"""Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        # Standard Dice has 2*|X∩Y| in the numerator; without the 2 a perfect
        # prediction would give a loss of about 0.5 instead of 0
        num = 2 * torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth

        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        # Channel i is 1 where the label equals i (so with n_classes=1 it marks label == 0)
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        # An extra dimension appears here; squeeze it out
        if not softmax:
            inputs = torch.squeeze(inputs)
            target = torch.squeeze(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes


def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    else:
        return 0, 0


def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice = image[ind, :, :]
            x, y = slice.shape[0], slice.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previous using 0
            input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                outputs = net(input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
    else:
        input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))

    if test_save_path is not None:
        img_itk = sitk.GetImageFromArray(image.astype(np.float32))
        prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
        lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
        img_itk.SetSpacing((1, 1, z_spacing))
        prd_itk.SetSpacing((1, 1, z_spacing))
        lab_itk.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
        sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
        sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
    return metric_list
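For reference, a minimal numpy sketch of the soft Dice loss in the form used by DiceLoss._dice_loss above, 1 - (2·Σp·t + s) / (Σp² + Σt² + s), computed per sample. It only illustrates the arithmetic on a toy mask; the function name soft_dice_loss is hypothetical and is not part of the Swin-Unet repo.

```python
import numpy as np


def soft_dice_loss(pred, target, smooth=1.0, p=2):
    """Per-sample soft Dice loss (hypothetical helper, numpy only).

    pred, target: arrays of shape [N, ...]; pred holds probabilities in [0, 1].
    Returns an array of shape [N], like reduction='none'.
    """
    pred = pred.reshape(pred.shape[0], -1).astype(np.float64)
    target = target.reshape(target.shape[0], -1).astype(np.float64)
    # The factor 2 makes a perfect prediction give exactly 0 loss
    num = 2 * np.sum(pred * target, axis=1) + smooth
    den = np.sum(pred ** p + target ** p, axis=1) + smooth
    return 1 - num / den


mask = np.zeros((1, 4, 4))
mask[0, 1:3, 1:3] = 1  # a 2x2 foreground square (4 pixels)

perfect = soft_dice_loss(mask, mask)      # loss 0
inverted = soft_dice_loss(1 - mask, mask)  # loss 1 - 1/17 = 16/17
```

Here a perfect prediction gives a loss of 0 and the fully inverted prediction gives 16/17. If the factor 2 is dropped from the numerator, the same perfect prediction gives 1 - 5/9 ≈ 0.44 instead, and for large masks that floor approaches 0.5.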

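test.py: major changes. The original test script loads the labels to compute evaluation metrics; I simply commented that part out and added my own data loading, so it only runs prediction so you can look at the results, with no evaluation.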

import argparse
import logging
import os
import random
import sys
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.dataset_synapse import Synapse_dataset
from utils import test_single_volume
from networks.vision_transformer import SwinUnet as ViT_seg
from trainer import trainer_synapse
from config import get_config
from datasets.dataset_synapse import ImageFolder

parser = argparse.ArgumentParser()
parser.add_argument('--volume_path', type=str,default='../data/Synapse/test_vol_h5', help='root dir for validation volume data')  # for acdc volume_path=root_dir
parser.add_argument('--dataset', type=str,default='Synapse', help='experiment_name')
parser.add_argument('--num_classes', type=int,default=1, help='output channel of network')
parser.add_argument('--list_dir', type=str,default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--output_dir', type=str, default='./predictions/', help='output dir')
parser.add_argument('--max_iterations', type=int,default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int, default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=6,help='batch_size per gpu')
parser.add_argument('--img_size', type=int, default=512, help='input patch size of network input')
parser.add_argument('--is_savenii', action="store_true", help='whether to save results during inference')
parser.add_argument('--test_save_dir', type=str, default='../predictions', help='saving prediction as nii!')
parser.add_argument('--deterministic', type=int,  default=1, help='whether use deterministic training')
parser.add_argument('--base_lr', type=float,  default=0.01, help='segmentation network learning rate')
parser.add_argument('--seed', type=int, default=1234, help='random seed')
# parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
parser.add_argument('--cfg', type=str, default='./configs/swin_tiny_patch4_window7_224_lite.yaml' , required=False, metavar="FILE", help='path to config file', )
parser.add_argument("--opts",help="Modify config options by adding 'KEY VALUE' pairs. ",default=None,nargs='+',)
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],help='no: no cache, ''full: cache all data, ''part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
args = parser.parse_args()
if args.dataset == "Synapse":
    args.volume_path = os.path.join(args.volume_path, "test_vol_h5")
config = get_config(args)


def inference(args, model, test_save_path=None):
    db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("{} test iterations per epoch".format(len(testloader)))
    model.eval()
    metric_list = 0.0
    for i_batch, sampled_batch in tqdm(enumerate(testloader)):
        h, w = sampled_batch["image"].size()[2:]
        image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
        metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size],
                                      test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing)
        metric_list += np.array(metric_i)
        logging.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1]))
    metric_list = metric_list / len(db_test)
    for i in range(1, args.num_classes):
        logging.info('Mean class %d mean_dice %f mean_hd95 %f' % (i, metric_list[i-1][0], metric_list[i-1][1]))
    performance = np.mean(metric_list, axis=0)[0]
    mean_hd95 = np.mean(metric_list, axis=0)[1]
    logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f' % (performance, mean_hd95))
    return "Testing Finished!"


# def inference(model, test_root, test_save_path):
#     db_test = ImageFolder(test_root,mode='test')
#     testloader = DataLoader(
#         db_test,
#         batch_size=1,
#         shuffle=True,
#         num_workers=0)
#     for image_batch, label_batch in testloader:
#         image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
#         outputs = model(image_batch)
#         print(outputs.shape)

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def inference_single(model, model_path, test_path, save_path):
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    os.makedirs(save_path, exist_ok=True)  # cv2.imwrite fails silently if the folder is missing
    im_names = os.listdir(test_path)
    for name in im_names:
        full_path = os.path.join(test_path, name)
        img = cv2.imread(full_path)
        # img = cv2.resize(img, (512,512), interpolation = cv2.INTER_NEAREST)
        # image = np.array(img, np.float32) / 255.0
        # Scale [0, 255] to [-1.6, 1.6]; this should match the normalization used during training
        image = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        image = np.array(image, np.float32).transpose(2, 0, 1)  # HWC -> CHW
        image = np.expand_dims(image, axis=0)  # add the batch dimension
        image = torch.Tensor(image).to(DEVICE)
        with torch.no_grad():
            output = model(image).cpu().data.numpy()
        # Binarize the output at 0.5
        output[output < 0.5] = 0
        output[output >= 0.5] = 1
        output = np.squeeze(output)
        save_full = os.path.join(save_path, name)
        cv2.imwrite(save_full, (output * 255).astype(np.uint8))


if __name__ == "__main__":
    # if not args.deterministic:
    #     cudnn.benchmark = True
    #     cudnn.deterministic = False
    # else:
    #     cudnn.benchmark = False
    #     cudnn.deterministic = True
    # random.seed(args.seed)
    # np.random.seed(args.seed)
    # torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    # dataset_config = {
    #     'Synapse': {
    #         'Dataset': Synapse_dataset,
    #         'volume_path': args.volume_path,
    #         'list_dir': './lists/lists_Synapse',
    #         'num_classes': 9,
    #         'z_spacing': 1,
    #     },
    # }
    # dataset_name = args.dataset
    # args.num_classes = dataset_config[dataset_name]['num_classes']
    # args.volume_path = dataset_config[dataset_name]['volume_path']
    # args.Dataset = dataset_config[dataset_name]['Dataset']
    # args.list_dir = dataset_config[dataset_name]['list_dir']
    # args.z_spacing = dataset_config[dataset_name]['z_spacing']
    # args.is_pretrain = True
    # net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda()
    # snapshot = os.path.join(args.output_dir, 'best_model.pth')
    # if not os.path.exists(snapshot): snapshot = snapshot.replace('best_model', 'epoch_'+str(args.max_epochs-1))
    # msg = net.load_state_dict(torch.load(snapshot))
    # print("self trained swin unet",msg)
    # snapshot_name = snapshot.split('/')[-1]
    # log_folder = './test_log/test_log_'
    # os.makedirs(log_folder, exist_ok=True)
    # logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    # logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    # logging.info(str(args))
    # logging.info(snapshot_name)
    # if args.is_savenii:
    #     args.test_save_dir = os.path.join(args.output_dir, "predictions")
    #     test_save_path = args.test_save_dir
    #     os.makedirs(test_save_path, exist_ok=True)
    # else:
    #     test_save_path = None
    # inference(args, net, test_save_path)
    args = parser.parse_args()
    config = get_config(args)
    net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).to(DEVICE)
    test_root = 'D:/csdn/Swin-Unet/data/build512/val/images/'
    test_save_path = './predictions/'
    model_path = './weights/epoch_179.pth'
    inference_single(net, model_path, test_root, test_save_path)
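The pre- and post-processing in inference_single can be checked without a model or a GPU. The sketch below pulls those steps into two hypothetical helpers, preprocess and postprocess, assuming (as the 0.5 threshold in inference_single suggests) that the network's single output channel already holds probabilities; if your model returns raw logits, a sigmoid would have to be applied first.

```python
import numpy as np


def preprocess(img_bgr):
    """HxWx3 uint8 image (as returned by cv2.imread) -> 1x3xHxW float32 array.

    Same scaling as inference_single: [0, 255] -> [-1.6, 1.6].
    """
    image = img_bgr.astype(np.float32) / 255.0 * 3.2 - 1.6
    image = image.transpose(2, 0, 1)       # HWC -> CHW
    return np.expand_dims(image, axis=0)   # add batch dimension -> NCHW


def postprocess(output, threshold=0.5):
    """1x1xHxW probability map -> HxW uint8 mask with values in {0, 255}."""
    mask = (np.squeeze(output) >= threshold).astype(np.uint8)
    return mask * 255


x = preprocess(np.full((4, 6, 3), 255, dtype=np.uint8))
m = postprocess(np.array([[[[0.2, 0.7], [0.5, 0.49]]]]))
```

Keeping these two steps in separate functions makes it easy to confirm that the test-time normalization matches whatever the training data loader does, which is a common cause of "the predictions are all black" problems.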

2. Training

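After making the changes above, set the marked places (your own paths and the parameters you want to use) and you can start training directly from the command line with

python train.py (note: if you change img_size, remember to change the corresponding value in config.py as well, otherwise it will raise an error).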
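A likely reason for that error: in the upstream repo, SwinUnet appears to build the encoder from the config values (DATA.IMG_SIZE, MODEL.SWIN.PATCH_SIZE, MODEL.SWIN.WINDOW_SIZE) rather than from --img_size, and each Swin stage splits its feature map into non-overlapping windows, so every stage resolution has to be divisible by the window size. The hypothetical helper below checks this, assuming the usual Swin-Unet layout of a stride-4 patch embedding followed by 2× patch merging across 4 stages; it is a sanity check, not code from the repo.

```python
def check_swin_unet_sizes(img_size, patch_size=4, window_size=7, num_stages=4):
    """Return a list of problems, or [] if the sizes look consistent.

    Assumes a stride-`patch_size` patch embedding, then 2x patch merging between
    `num_stages` stages (img/4, img/8, img/16, img/32 with the defaults).
    """
    problems = []
    if img_size % patch_size != 0:
        problems.append(f"img_size {img_size} is not divisible by patch_size {patch_size}")
        return problems
    res = img_size // patch_size
    for stage in range(num_stages):
        # A stage whose resolution is <= window_size uses a single window for the whole map
        if res > window_size and res % window_size != 0:
            problems.append(f"stage {stage}: resolution {res} is not divisible by window_size {window_size}")
        if stage < num_stages - 1:
            if res % 2 != 0:
                problems.append(f"stage {stage}: resolution {res} is odd, patch merging needs an even size")
            res //= 2
    return problems


for size, window in [(224, 7), (512, 7), (512, 8)]:
    print(size, window, check_swin_unet_sizes(size, window_size=window))
```

With the default window size 7, 224 works (56, 28, 14, 7) but 512 does not, since 128 is not divisible by 7; a window size of 8 makes 512 consistent (128, 64, 32, 16). Changing the window size also changes the shape of the relative position bias tables, so pretrained window-7 weights will not match those layers exactly.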

3. Prediction

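Same as above: set your own paths in test.py (test_root, test_save_path, model_path) and run python test.py.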
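Because the modified test.py only writes masks and no longer computes metrics, here is a small numpy sketch for scoring a saved prediction against its label. The function name binary_dice_iou and the convention of returning 1.0 when both masks are empty are assumptions made for this sketch, not behavior taken from the repo.

```python
import numpy as np


def binary_dice_iou(pred, gt):
    """Dice and IoU between two binary masks; any value > 0 counts as foreground.

    Assumed convention: if both masks are empty, both scores are 1.0.
    """
    pred = np.asarray(pred) > 0
    gt = np.asarray(gt) > 0
    inter = np.logical_and(pred, gt).sum()
    total = pred.sum() + gt.sum()
    union = np.logical_or(pred, gt).sum()
    if total == 0:
        return 1.0, 1.0
    return 2.0 * inter / total, inter / union


pred = np.array([[0, 255], [255, 0]], dtype=np.uint8)
gt = np.array([[0, 255], [0, 0]], dtype=np.uint8)
dice, iou = binary_dice_iou(pred, gt)  # 1 overlapping pixel, 3 foreground pixels in total
```

To use it on the saved PNGs, read each prediction and its label with cv2.imread(path, cv2.IMREAD_GRAYSCALE) and average the scores over the folder.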

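Everything above is the modified code with comments; if you follow the changes step by step it should definitely run. The paid download below is optional. I suggest ignoring it and only considering it if you really cannot get things working.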

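Swin-Unet-Transformer network for semantic segmentation (binary classification), CSDN download: 1. adds the data-loading part and a binary-classification loss; 2. necessary Chinese comments; 3. includes my own dataset; 4. contact me anytime with questions. https://download.csdn.net/download/qq_20373723/85012614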

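Side note: if you know of any good new networks, recommend them in the comments and I will reproduce them and post them here for everyone to try.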
