下面是PointRend的源码位置,接下来先跑下看看

GitHub - zsef123/PointRend-PyTorch: A PyTorch implementation of PointRend: Image Segmentation as Renderinghttps://github.com/zsef123/PointRend-PyTorch

(1)数据准备

数据就用公共数据集CamVid,该数据集加背景0共12个类,标签值为0-11,下面是一级目标,目录结构及文件名务必保持一致,因为我后面在数据读取的时候添加了读自己数据集的数据导入函数,文件夹名字是固定了的,当然你也可以改代码。

二级目录,train/val/test,目录结构需要一致,另外如果test只有图像也可以不要labels文件夹

(2)添加自己的数据加载模块

在 __init__.py文件中添加了get_own函数,加完以后在get_loader函数添加自己数据的引导,另外需要强调下,我自己添加的数据加载没有专门加数据扩充策略,你们自己加下,加了效果应该会好点。

__init__.py代码:

import os
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.voc import VOCSegmentation
from torchvision.datasets.cityscapes import Cityscapesfrom .transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomFlip, ConvertMaskIDdef get_voc(C, split="train"):if split == "train":transforms = Compose([ToTensor(),RandomCrop((256, 256)),Resize((256, 256)),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])else:transforms = Compose([ToTensor(),Resize((256, 256)),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return VOCSegmentation(C['root'], download=True, image_set=split, transforms=transforms)def get_cityscapes(C, split="train"):if split == "train":# Appendix B. Semantic Segmentation Detailstransforms = Compose([ToTensor(),RandomCrop(768),ConvertMaskID(Cityscapes.classes),RandomFlip(),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])else:transforms = Compose([ToTensor(),Resize(768),ConvertMaskID(Cityscapes.classes),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])return Cityscapes(**C, split=split, transforms=transforms)class get_own(torch.utils.data.Dataset):def __init__(self, C, split="train"):images_path = os.path.join(C['root'], split, 'images')labels_path = os.path.join(C['root'], split, 'labels')images_path_list = []labels_path_list = []imgs = os.listdir(images_path)for name in imgs:img_full_path = os.path.join(images_path, name)lab_full_path = os.path.join(labels_path, name)images_path_list.append(img_full_path)labels_path_list.append(lab_full_path)self.images_path_list = images_path_listself.labels_path_list = labels_path_listif split == "train":# Appendix B. Semantic Segmentation DetailsTransform = Compose([ToTensor(),# RandomCrop(256),# ConvertMaskID(Cityscapes.classes),# RandomFlip()# Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])else:Transform = Compose([ToTensor(),# Resize(256),# ConvertMaskID(Cityscapes.classes),# Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])self.transform = Transformdef __getitem__(self,index):  image_path = self.images_path_list[index]label_path = self.labels_path_list[index]# image = Image.open(image_path).convert('RGB')# label = Image.open(label_path)image = cv2.imread(image_path)label = cv2.imread(label_path, 0)image = np.array(image, np.float32) / 255.0label = np.array(label, np.float32)# image = self.transform('images':image)# label = self.transform('masks':label)image, label = self.transform(image, label)# image = image.type(torch.FloatTensor)# label = label.type(torch.FloatTensor)return image, labeldef __len__(self):return len(self.images_path_list)def get_loader(C, split, distributed):"""Args:C (Config): C.datasplit (str): args of dataset,The image split to use, ``train``, ``test`` or ``val`` if split="gtFine"otherwise ``train``, ``train_extra`` or ``val`"""print(C.name, C.dataset, split)if C.name == "cityscapes":dset = get_cityscapes(C.dataset, split)elif C.name == "pascalvoc":dset = get_voc(C.dataset, split)elif C.name == "own":dset = get_own(C.dataset, split)if split == "train":shuffle = Truedrop_last = Falseelse:shuffle = Falsedrop_last = Falsesampler = Noneif distributed:sampler = DistributedSampler(dset, shuffle=shuffle)shuffle = Nonereturn DataLoader(dset, **C.loader, sampler=sampler,shuffle=shuffle, drop_last=drop_last,pin_memory=True)

(3)训练

这个GitHub项目结构比较好,训练模块在train.py中,不需要改,主要改main.py文件中的部分东西,由于这个项目用了apex来加速训练,而我这里安装不方便,还报错了,我main.py的主要改动就是注释掉apex相关的部分。

main.py代码

import os
import sys
import argparse
import logging
from tokenize import Double
from configs.parser import Parserimport torch# from apex import amp
# from apex.parallel import DistributedDataParallel as ApexDDPfrom model import deeplabv3, PointHead, PointRend
from datas import get_loader
from train import train
from utils.gpus import synchronize, is_main_processdef parse_args():parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")parser.add_argument("--config", type=str, default="./configs/default.yaml", help="It must be config/*.yaml")  #yaml文件是必要的配置文件,后面会简要说明parser.add_argument("--save", type=str, default="build", help="Save path in out directory")parser.add_argument("--local_rank", type=int, default=0, help="Using for Apex DDP")return parser.parse_args()def amp_init(args):# Apex Initializeargs.distributed = Falseif 'WORLD_SIZE' in os.environ:args.distributed = int(os.environ['WORLD_SIZE']) > 1if args.distributed:torch.cuda.set_device(args.local_rank)torch.distributed.init_process_group(backend="nccl", init_method="env://")synchronize()torch.backends.cudnn.benchmark = Truedef set_loggging(save_dir):if not os.path.exists(save_dir):os.makedirs(save_dir)log_format = '%(asctime)s %(message)s'logging.basicConfig(stream=sys.stdout, level=logging.INFO,format=log_format, datefmt='[%y/%m/%d %H:%M:%S]')fh = logging.FileHandler(f"{save_dir}/log.txt")fh.setFormatter(logging.Formatter(log_format))logging.getLogger().addHandler(fh)if __name__ == "__main__":args = parse_args()amp_init(args)parser = Parser(args.config)C = parser.Csave_dir = f"{os.getcwd()}/outs/{args.save}"if is_main_process():if not os.path.exists(save_dir):os.makedirs(save_dir, mode=0o775)parser.dump(f"{save_dir}/config.yaml")set_loggging(save_dir)device = torch.device("cuda")train_loader = get_loader(C.data, "train", distributed=args.distributed)valid_loader = get_loader(C.data, "val", distributed=args.distributed)net = PointRend(deeplabv3(**C.net.deeplab),PointHead(**C.net.pointhead)).to(device)params = [{"params": net.backbone.backbone.parameters(),   "lr": float(C.train.lr)},{"params": net.head.parameters(),                "lr": float(C.train.lr)},{"params": net.backbone.classifier.parameters(), "lr": float(C.train.lr) * 10}]# optim = torch.optim.SGD(params, momentum=C.train.momentum, weight_decay=C.train.weight_decay)#这里尝试了用adamw优化器训练optim = torch.optim.AdamW(params, lr=float(C.train.lr), weight_decay=float(C.train.weight_decay))#这里注释了需要apex加速的模块# net, optim = amp.initialize(net, optim, opt_level=C.apex.opt)# if args.distributed:#     net = ApexDDP(net, delay_allreduce=True)train(C.run, save_dir, train_loader, valid_loader, net, optim, device)#Apex混合精度加速 介绍:为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。
# 号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。
# 该项目开源于:https://github.com/NVIDIA/apex ,文档地址是:https://nvidia.github.io/apex/index.html该工具提供了三个功能,amp、parallel和normalization。

训练用的default.yaml文件

data:name: "own"dataset:root: "./datasets/CamVid/"mode: "fine"target_type: "semantic"loader:batch_size: 5num_workers: 0net:deeplab:pretrained: Falseresnet: "res101"head_in_ch: 2048num_classes: 12pointhead:in_c: 524 # 512 + num_classesnum_classes: 12k: 3beta: 0.75run:epochs: 101train:lr: 1e-3       momentum: 0.9weight_decay: 1e-3apex:opt: "O0"

(4)预测

原始的预测用的是infer.py文件,这个预测要加载标签,而且会给出精度评价,我考虑到会有直接预测而不加标签预测的情况,改了一个预测代码

predict.py代码:

import os
import time
import logging
import cv2
from PIL import Image
import numpy as np
import torch
import argparse
from configs.parser import Parser
from model import deeplabv3, PointHead, PointRend
from utils.metrics import ConfusionMatrix
from utils.gpus import synchronize, is_main_process@torch.no_grad()
def infer(loader, net, device):net.eval()num_classes = 2 # Hard coding for Cityscapesmetric = ConfusionMatrix(num_classes)for i, (x, gt) in enumerate(loader):x = x.to(device, non_blocking=True)gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)pred = net(x)["fine"].argmax(1)metric.update(pred, gt)mIoU = metric.mIoU()logging.info(f"[Infer] mIOU : {mIoU}")return mIoUdef parse_args():parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")parser.add_argument("--config", type=str, default="./configs/default.yaml", help="It must be config/*.yaml")parser.add_argument("--save", type=str, default="build", help="Save path in out directory")parser.add_argument("--local_rank", type=int, default=0, help="Using for Apex DDP")return parser.parse_args()def amp_init(args):# Apex Initializeargs.distributed = Falseif 'WORLD_SIZE' in os.environ:args.distributed = int(os.environ['WORLD_SIZE']) > 1if args.distributed:torch.cuda.set_device(args.local_rank)torch.distributed.init_process_group(backend="nccl", init_method="env://")synchronize()torch.backends.cudnn.benchmark = Truedef predict(data_path, model_path, net, save_path):net.load_state_dict(torch.load(model_path))net.eval()img_names = os.listdir(data_path)for ele in img_names:full_path = os.path.join(data_path, ele)# image = Image.open(full_path).convert('RGB')image = cv2.imread(full_path)image = np.array(image, np.float32) / 255.0# image = np.array(image)image = image.transpose(2,0,1)image = np.expand_dims(image, axis=0)# image = torch.from_numpy(image)image = torch.FloatTensor(image)x = image.to(device, non_blocking=True)pred = net(x)["fine"].argmax(1)# pred = net(x)["fine"]save_full_path = os.path.join(save_path, ele)pred = pred.cpu().numpy()cv2.imwrite(save_full_path, pred[0])if __name__ == "__main__":path = './datasets/CamVid/test/'save_path = './datasets/pred/'model_path = './outs/CamVid/epoch_0100_loss_0.54185.pth'args = parse_args()amp_init(args)parser = Parser(args.config)C = parser.Cdevice = torch.device("cuda")net = PointRend(deeplabv3(**C.net.deeplab),PointHead(**C.net.pointhead)).to(device)predict(path, model_path, net, save_path)

效果:

  

图像                                                                        标签

预测结果

从结果看,很明显效果不理想,不过不要太过悲观,因为我去掉了加速模块,这个训练有点慢,我训练了100个epoch就停掉了,并且数据也没有做增强,效果肯定还可以提高的。

PointRend使用记录相关推荐

  1. mysql建立联合索引,mysql建立唯一键,mysql如何解决重复记录联合索引

    在项目中,常常要用到联合唯一   在一些配置表中,一些列的组合成为一条记录.   比如,在游戏中,游戏的分区和用户id会形成一条记录.(比如,一个qq用户可以在艾欧尼亚.德玛西亚创建两个账号) 添加联 ...

  2. 实现 连续15签到记录_MySQL和Redis实现用户签到,你喜欢怎么实现?

    现在的网站和app开发中,签到是一个很常见的功能 如微博签到送积分,签到排行榜 微博签到 如移动app ,签到送流量等活动, 移动app签到 用户签到是提高用户粘性的有效手段,用的好能事半功倍! 下面 ...

  3. 记录一次http请求失败的问题分析

    问题背景 当前我有一个基于Flask编写的Restful服务,由于业务的需求,我需要将该服务打包成docker 镜像进行离线部署,原始服务的端口是在6661端口进行开启,为了区分,在docker中启动 ...

  4. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  5. LeetCode简单题之学生出勤记录 I

    题目 给你一个字符串 s 表示一个学生的出勤记录,其中的每个字符用来标记当天的出勤情况(缺勤.迟到.到场).记录中只含下面三种字符: 'A':Absent,缺勤 'L':Late,迟到 'P':Pre ...

  6. 关于TVM的点滴记录

    关于TVM的点滴记录

  7. MySql数据库Update批量更新与批量更新多条记录的不同值实现方法

    批量更新 mysql更新语句很简单,更新一条数据的某个字段,一般这样写: UPDATE mytable SET myfield = 'value' WHERE other_field = 'other ...

  8. 记录篇,自己在项目中使用过的。

    图片选择器,6.0已经适配过,类似qq空间上传 点击打开链接_胡小牧记录 下面是效果图: PictureSelector PhotoPicker 类似qq空间发布心情. 点击打开链接 BubbleSe ...

  9. HTML5与CSS3权威指南之CSS3学习记录

    title: HTML5与CSS3权威指南之CSS3学习记录 toc: true date: 2018-10-14 00:06:09 学习资料--<HTML5与CSS3权威指南>(第3版) ...

最新文章

  1. 初探 Unix 操作系统
  2. [单选题]PDO::ATTR_ERRMODE设置为以下哪个值时,PDO会抛出PDOException?
  3. C#的同步和异步调用方法
  4. 阿里云原生张羽辰:服务发现技术选型那点事儿
  5. 电脑上玩和平精英_《和平精英》怎么投屏到电脑上?手把手教你电脑键鼠玩手游...
  6. Oracle 11g中关于数据定义的思考
  7. 从计算机系统结构的发展和演变看,近代计算机是以,西南民族大学计算机系统结构试卷B有答案.doc...
  8. iOS 浅谈:深.浅拷贝与copy.strong
  9. 408计算机及格要什么水平,2019考研计算机408难度水平?
  10. 2012 r2 万能网卡驱动_MultiBeast | 黑苹果的驱动精灵简单使用解读
  11. Excel远程连接Oracle,excel连接数据库_怎么用oracle命令连接远程数据库�9�3
  12. 简书网页劫持分析,网站劫持,利用 CSP 预防劫持
  13. Ubuntu deb文件 安装 MySQL
  14. 仿淘宝详情页上拉看详情
  15. 导师吐槽大会:自己招的学生,哭着也要带完
  16. Freemaker之代码生成
  17. 张寓博当选山东省收藏者协会副主席兼美术评论委员会主任
  18. 普通的视觉工程师的待遇是怎样的?
  19. mcnpf5输出结果_MCNP及使用.ppt
  20. Windows消息处理

热门文章

  1. 为了成长,我豁出去了!同程苏州,我来了!
  2. iOS_使用金山快盘管理项目
  3. RNA 1. SCI 文章中读取 GEO 数据
  4. 卷积神经网络-猫狗识别(附源码)
  5. select二级联动价格策略+js的eval()
  6. [ORACLE]数据字典视图大全
  7. 手把手教你搭建SpringCloudAlibaba之Nacos服务配置中心
  8. A律13折现编解码实现,SystemVerilog实现,实测可用带完整的testbench
  9. (三)爬取一些网页图片
  10. Postgresql快照优化Globalvis新体系分析(性能大幅增强)