Dataset download links

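The dataset comes from the Tianchi competition: 零基础入门语义分割-地表建筑物识别 (Semantic Segmentation for Beginners: Ground Building Recognition), Tianchi Competition, Alibaba Cloud Tianchi (aliyun.com)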

| File                    | Size     | Download link |
| ----------------------- | -------- | ------------- |
| test_a.zip              | 314.49MB | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/test_a.zip |
| test_a_samplesubmit.csv | 46.39KB  | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/test_a_samplesubmit.csv |
| train.zip               | 3.68GB   | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/train.zip |
| train_mask.csv.zip      | 97.52MB  | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/train_mask.csv.zip |
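
The files in the table can also be fetched from a script. The following is a minimal sketch using only the standard library. The names `build_url` and `download_all` and the `./data` target directory are assumptions made for illustration, and the links may no longer be reachable. The archives still need to be unpacked so that the images end up under `./data/train/` and `./data/test_a/`, which are the paths the code in this post reads from.

```python
from pathlib import Path
from urllib.parse import quote
from urllib.request import urlretrieve

BASE_URL = "http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/"
# Decoded form of the percent-encoded directory name in the links above
REMOTE_DIR = "地表建筑物识别"
FILES = [
    "test_a.zip",
    "test_a_samplesubmit.csv",
    "train.zip",
    "train_mask.csv.zip",
]


def build_url(filename):
    """Rebuild the download link for one file (hypothetical helper)."""
    return BASE_URL + quote(REMOTE_DIR) + "/" + filename


def download_all(target_dir="./data"):
    """Download every file that is not already present in target_dir."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    for name in FILES:
        dest = target / name
        if not dest.exists():
            urlretrieve(build_url(name), str(dest))
```

Calling `download_all()` fetches about 4 GB in total, so it is worth running once and keeping the files.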

Full code

Training code

#!/usr/bin/env python
# coding: utf-8
import logging
import os
import random
import time
import warnings

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.utils.data as D
from sklearn.model_selection import KFold
from torchvision import transforms as T
from tqdm import tqdm

warnings.filterwarnings('ignore')

EPOCHES = 120
BATCH_SIZE = 4
IMAGE_SIZE = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

logging.basicConfig(
    filename='log_unet_sh_fold_4_s.log',
    format='%(asctime)s - %(name)s - %(levelname)s -%(module)s:  %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S ',
    level=logging.INFO,
)


def set_seeds(seed=42):
    """Fix all random seeds so that runs are repeatable."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seeds()


def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels = im.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')


# Note: RandomContrast and RandomBrightness are deprecated in newer
# albumentations releases; this pipeline assumes a version that still has them.
train_trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
    A.OneOf([
        A.RandomContrast(),
        A.RandomGamma(),
        A.RandomBrightness(),
        A.ColorJitter(brightness=0.07, contrast=0.07,
                      saturation=0.1, hue=0.1, always_apply=False, p=0.3),
    ], p=0.3),
])

# Defined here but not used below: the validation subset shares train_trfm
# because both subsets are views of the same dataset object.
val_trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90()
])


class TianChiDataset(D.Dataset):
    def __init__(self, paths, rles, transform, test_mode=False):
        self.paths = paths
        self.rles = rles
        self.transform = transform
        self.test_mode = test_mode
        self.len = len(paths)
        self.as_tensor = T.Compose([
            T.ToPILImage(),
            T.Resize(IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize([0.625, 0.448, 0.688],
                        [0.131, 0.177, 0.101]),
        ])

    # get data operation
    def __getitem__(self, index):
        img = cv2.imread(self.paths[index])
        if not self.test_mode:
            mask = rle_decode(self.rles[index])
            augments = self.transform(image=img, mask=mask)
            return self.as_tensor(augments['image']), augments['mask'][None]
        else:
            return self.as_tensor(img), ''

    def __len__(self):
        """Total number of samples in the dataset"""
        return self.len


train_mask = pd.read_csv('./data/train_mask.csv', sep='\t', names=['name', 'mask'])
train_mask['name'] = train_mask['name'].apply(lambda x: './data/train/' + x)

# Quick sanity check: read the first image and decode its mask
img = cv2.imread(train_mask['name'].iloc[0])
mask = rle_decode(train_mask['mask'].iloc[0])

dataset = TianChiDataset(
    train_mask['name'].values,
    train_mask['mask'].fillna('').values,
    train_trfm, False
)

skf = KFold(n_splits=5)
idx = np.array(range(len(dataset)))


@torch.no_grad()
def validation(model, loader, loss_fn):
    losses = []
    model.eval()
    for image, target in loader:
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        output = model(image)
        loss = loss_fn(output, target)
        losses.append(loss.item())
    return np.array(losses).mean()


class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1., dims=(-2, -1)):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()
        return 1 - dc


bce_fn = nn.BCEWithLogitsLoss()  # nn.NLLLoss()
dice_fn = SoftDiceLoss()


def loss_fn(y_pred, y_true, ratio=0.8, hard=False):
    bce = bce_fn(y_pred, y_true)
    if hard:
        # Threshold first, then cast: `1 - x` is not defined for bool tensors
        dice = dice_fn((y_pred.sigmoid() > 0.5).float(), y_true)
    else:
        dice = dice_fn(y_pred.sigmoid(), y_true)
    return ratio * bce + (1 - ratio) * dice


header = r'''
        Train | Valid
Epoch |  Loss |  Loss | Time, m
'''
#          Epoch         metrics            time
raw_line = '{:6d}' + '\u2502{:7.4f}' * 2 + '\u2502{:6.2f}'
print(header)

for fold_idx, (train_idx, valid_idx) in enumerate(skf.split(idx, idx)):
    # select fold: only the last one (index 4) is trained here
    if fold_idx != 4:
        continue
    train_ds = D.Subset(dataset, train_idx)
    valid_ds = D.Subset(dataset, valid_idx)
    # define training and validation data loaders
    loader = D.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    vloader = D.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    fold_model_path = 'fold4_unet_model_new4_s.pth'
    model = smp.Unet(
        encoder_name="efficientnet-b4",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights='imagenet',  # use `imagenet` pretrained weights for encoder initialization
        in_channels=3,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
        classes=1,  # model output channels (number of classes in your dataset)
    )
    # Resume from an earlier checkpoint when one exists; skip on the first run
    if os.path.exists(fold_model_path):
        model.load_state_dict(torch.load(fold_model_path, map_location='cpu'))
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=1e-6, last_epoch=-1)
    best_loss = 10
    for epoch in range(1, EPOCHES + 1):
        losses = []
        start_time = time.time()
        model.train()
        for image, target in tqdm(loader):
            image, target = image.to(DEVICE), target.float().to(DEVICE)
            optimizer.zero_grad()
            output = model(image)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        vloss = validation(model, vloader, loss_fn)
        # CosineAnnealingWarmRestarts.step() takes an optional epoch, not a
        # loss value, so the validation loss is not passed in
        scheduler.step()
        logging.info(raw_line.format(epoch, np.array(losses).mean(), vloss,
                                     (time.time() - start_time) / 60))
        losses = []
        if vloss < best_loss:
            best_loss = vloss
            torch.save(model.state_dict(), 'fold{}_unet_model_new4_s.pth'.format(fold_idx))
    print("best loss is {}".format(best_loss))
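
To see what the soft Dice term in the training loss computes, the formula from `SoftDiceLoss` can be evaluated by hand on a tiny example. The sketch below mirrors that formula in NumPy rather than PyTorch so it runs without a GPU stack. The function name `soft_dice_loss` and the 2x2 inputs are made up for illustration.

```python
import numpy as np


def soft_dice_loss(x, y, smooth=1.0):
    """NumPy mirror of the SoftDiceLoss formula for a single 2-D mask.

    x: predicted probabilities in [0, 1]
    y: ground-truth mask with values 0 or 1
    """
    tp = (x * y).sum()          # soft true positives
    fp = (x * (1 - y)).sum()    # soft false positives
    fn = ((1 - x) * y).sum()    # soft false negatives
    dc = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    return 1 - dc


# Illustrative inputs: left column is building, right column is background
pred = np.array([[0.9, 0.1],
                 [0.8, 0.2]])
target = np.array([[1.0, 0.0],
                   [1.0, 0.0]])

# tp = 1.7, fp = 0.3, fn = 0.3, so dc = 4.4 / 5.0 = 0.88
loss = soft_dice_loss(pred, target)
perfect = soft_dice_loss(target, target)
```

Here `loss` is about 0.12 and `perfect` is 0. The `smooth` term keeps the ratio defined when both the prediction and the target are empty, which happens for tiles that contain no buildings.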

Test code

#!/usr/bin/env python
# coding: utf-8
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
from torchvision import transforms as T
from tqdm import tqdm

warnings.filterwarnings('ignore')

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_SIZE = 512

trfm = T.Compose([
    T.ToPILImage(),
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize([0.625, 0.448, 0.688],
                [0.131, 0.177, 0.101]),
])

subm = []

model = smp.Unet(
    encoder_name="efficientnet-b4",  # must match the encoder used in training
    encoder_weights=None,  # no ImageNet download needed: weights come from the checkpoint below
    in_channels=3,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=1,  # model output channels (number of classes in your dataset)
)
model.load_state_dict(torch.load("./fold4_unet_model_new4_s.pth", map_location='cpu'))
model.eval()
model = model.to(DEVICE)

test_mask = pd.read_csv('./data/test_a_samplesubmit.csv', sep='\t', names=['name', 'mask'])
test_mask['name'] = test_mask['name'].apply(lambda x: './data/test_a/' + x)


def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels = im.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')


for name in tqdm(test_mask['name']):
    image = cv2.imread(name)
    image = trfm(image)
    with torch.no_grad():
        image = image.to(DEVICE)[None]
        # A single forward pass; take the only channel of the only batch item
        score = model(image)[0][0]
        score_sigmoid = score.sigmoid().cpu().numpy()
    score_sigmoid = (score_sigmoid > 0.5).astype(np.uint8)
    # Nearest-neighbour keeps the mask strictly binary when resizing
    score_sigmoid = cv2.resize(score_sigmoid, (512, 512), interpolation=cv2.INTER_NEAREST)
    subm.append([name.split('/')[-1], rle_encode(score_sigmoid)])

subm = pd.DataFrame(subm)
subm.to_csv('./tmp.csv', index=None, header=None, sep='\t')

# Show the first predicted mask next to its input image
plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.imshow(rle_decode(subm[1].fillna('').iloc[0]), cmap='gray')
plt.subplot(122)
plt.imshow(cv2.imread('data/test_a/' + subm[0].iloc[0]))
plt.show()
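
The run-length encoding used for the submission file is easiest to follow on a tiny mask. The sketch below reuses `rle_encode` and `rle_decode` as they appear in this post and applies them to a made-up 3x3 mask. The mask values and the variable names `encoded` and `decoded` are for illustration only.

```python
import numpy as np


def rle_encode(im):
    """Encode a binary mask as 'start length' pairs, column-major, 1-indexed."""
    pixels = im.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape=(512, 512)):
    """Invert rle_encode for a mask of the given (height, width)."""
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')


# Illustrative 3x3 mask; read column by column it is 0,1,0, 1,1,0, 0,0,1
mask = np.array([[0, 1, 0],
                 [1, 1, 0],
                 [0, 0, 1]], dtype=np.uint8)

encoded = rle_encode(mask)                     # "2 1 4 2 9 1"
decoded = rle_decode(encoded, shape=(3, 3))    # same array as mask
```

Each pair reads as "start at this pixel, run for this many pixels", counting from 1 and walking down the columns. An all-background mask encodes to an empty string, which is why the scripts call `fillna('')` before decoding.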
