这个很不错:https://blog.csdn.net/qq_39056987/article/details/106455828     【windows10】使用pytorch版本deeplabv3+训练自己数据集

参考:https://blog.csdn.net/qq_36766560/article/details/110009622?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522161594607816780266214828%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=161594607816780266214828&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_v2~rank_v29-1-110009622.first_rank_v2_pc_rank_v29_10&utm_term=pytorch+deplabv3%2B

  • 开发环境
  • 数据集准备
    • 1.VOC数据集格式
      • `JPEGImages`里面放原图
      • `SegmentationClass`里面放对应的mask图片png格式,注意要和`JPEGImages`里的图片一一对应
      • `ImageSets/ Segmentation`的txt文件中放去掉后缀的图片名
    • 2.json转换mask图片
    • 3.提取出所有文件夹中的`label.png`并改成对应的名字放在指定目录中
  • 修改代码
    • 1.在`mypath.py`中添加自己的数据集名称与路径
    • 2.在同级目录中修改`train.py`约185行添加自己数据集的名称(可以设置为默认)
    • 3.在dataloaders目录下修改__init__.py
    • 4. 修改dateloaders目录下`utils.py`
    • 5.在dataloaders/datasets目录下添加文件
  • 运行并测试
    • 1.开始训练
    • 2.测试

开发环境

我使用的是实验室服务器的环境:

代码在这:链接

VOC数据集:链接:https://pan.baidu.com/s/1ncWVIVgntas3cJBFMqnfFg      提取码:64pp

预训练模型:链接:https://pan.baidu.com/s/1d58JnOtxa4JuJKO3tiZYqw        提取码:po5m

数据集准备

1.VOC数据集格式

文件安排如下:

- ImageSets- Segmentation- train.txt- trainval.txt- val.txt
- JPEGImages
- SegmentationClass

JPEGImages里面放原图

SegmentationClass里面放对应的mask图片png格式,注意要和JPEGImages里的图片一一对应

ImageSets/ Segmentation的txt文件中放去掉后缀的图片名

2.json转换mask图片,(没做,制作自己的数据集时用到)

数据转换参考的这里

import argparse
import base64
import json
import os
import os.path as ospimport PIL.Image
import yamlfrom labelme.logger import logger
from labelme import utilspath = "G:/Seg552VOC/seg552Json"
dirs = os.listdir(path)def label(json_file, out_dir, label_name_to_value):data = json.load(open(json_file))if data['imageData']:imageData = data['imageData']else:imagePath = os.path.join(os.path.dirname(json_file), data['imagePath'])with open(imagePath, 'rb') as f:imageData = f.read()imageData = base64.b64encode(imageData).decode('utf-8')img = utils.img_b64_to_arr(imageData)for shape in sorted(data['shapes'], key=lambda x: x['label']):label_name = shape['label']if label_name in label_name_to_value:label_value = label_name_to_value[label_name]else:label_value = len(label_name_to_value)label_name_to_value[label_name] = label_valuelbl = utils.shapes_to_label(img.shape, data['shapes'], label_name_to_value)label_names = [None] * (max(label_name_to_value.values()) + 1)for name, value in label_name_to_value.items():label_names[value] = namelbl_viz = utils.draw_label(lbl, img, label_names)PIL.Image.fromarray(img).save(osp.join(out_dir, 'img.png'))utils.lblsave(osp.join(out_dir, 'label.png'), lbl)PIL.Image.fromarray(lbl_viz).save(osp.join(out_dir, 'label_viz.png'))with open(osp.join(out_dir, 'label_names.txt'), 'w') as f:for lbl_name in label_names:f.write(lbl_name + '\n')logger.warning('info.yaml is being replaced by label_names.txt')info = dict(label_names=label_names)with open(osp.join(out_dir, 'info.yaml'), 'w') as f:yaml.safe_dump(info, f, default_flow_style=False)logger.info('Saved to: {}'.format(out_dir))def main():logger.warning('This script is aimed to demonstrate how to convert the''JSON file to a single image dataset, and not to handle''multiple JSON files to generate a real-use dataset.')parser = argparse.ArgumentParser()parser.add_argument('json_file_dir')parser.add_argument('-o', '--out', default=None)args = parser.parse_args()label_name_to_value = {'_background_': 0}for json_file in dirs:if args.out is None:out_dir = osp.basename(json_file).replace('.', '_')out_dir = osp.join(osp.dirname(json_file), out_dir)else:out_dir = args.outif not osp.exists(out_dir):os.mkdir(out_dir)label(json_file, out_dir, label_name_to_value)
if __name__ == '__main__':main()

在conda prompt中启动labelme环境运行labelme_json_to_dataset 输出的目录,得到很多文件夹,每个文件夹里都有四个文件,我们只需要label.png

3.提取出所有文件夹中的label.png并改成对应的名字放在指定目录中,(没做,制作自己的数据集时用到)

label.py代码如下:

import os
import shutilinputdir = 'G:/Seg552VOC/datasetDir'
outputdir = 'G:/Seg552VOC/ImageSetNew'for dir in os.listdir(inputdir):# 设置旧文件名(就是路径+文件名)oldname = inputdir + os.sep + dir + os.sep + 'label.png'  # os.sep添加系统分隔符# 设置新文件名newname = outputdir + os.sep + dir.split('_')[0] + '.png'shutil.copyfile(oldname, newname)  # 用os模块中的rename方法对文件改名print(oldname, '======>', newname)

修改代码

参考博客:链接

1.在mypath.py中添加自己的数据集名称与路径

2.在同级目录中修改train.py约185行添加自己数据集的名称(可以设置为默认)

3.在dataloaders目录下修改__init__.py

在第一行添加数据集名称,复制'pascal'数据集描述,把名称修改为自己数据集的名字

4. 修改dateloaders目录下utils.py

在76行左右添加代码,设置每一类别的颜色显示。

在24行左右添加代码,其中n_classes是你要分割的类别数

5.在dataloaders/datasets目录下添加文件

复制一份pascal.py,并重命名文件

修改里面的类别数和数据集名称

运行并测试

1.开始训练

运行指令如下

python train.py --backbone mobilenet --lr 0.007 --workers 1 --epochs 50 --batch-size 8 --gpu-ids 0 --checkname deeplab-mobilenet

–backbone mobilenet 指的是使用mobilenet作为backbone
–gpu-ids 0 指定gpu
–checkname deeplab-mobilenet 使用mobilenet预训练模型

2.测试

测试testdemo.py
修改–in-path为数据集的测试图片,最后的结果保存在–in-path中

#
# demo.py
#
import argparse
import os
import numpy as np
import timefrom modeling.deeplab import *
from dataloaders import custom_transforms as tr
from PIL import Image
from torchvision import transforms
from dataloaders.utils import  *
from torchvision.utils import make_grid, save_imagedef main():parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")parser.add_argument('--in-path', type=str,  default='/root/home/zyx/Seg552_VOC/test',help='image to test')# parser.add_argument('--out-path', type=str, required=True, help='mask image to save')parser.add_argument('--backbone', type=str, default='resnet',choices=['resnet', 'xception', 'drn', 'mobilenet'],help='backbone name (default: resnet)')parser.add_argument('--ckpt', type=str, default='deeplab-resnet.pth',help='saved model')parser.add_argument('--out-stride', type=int, default=16,help='network output stride (default: 8)')parser.add_argument('--no-cuda', action='store_true', default=False,help='disables CUDA training')parser.add_argument('--gpu-ids', type=str, default='0',help='use which gpu to train, must be a \comma-separated list of integers only (default=0)')parser.add_argument('--dataset', type=str, default='pascal',choices=['pascal', 'coco', 'cityscapes','invoice'],help='dataset name (default: pascal)')parser.add_argument('--crop-size', type=int, default=513,help='crop image size')parser.add_argument('--num_classes', type=int, default=4,help='crop image size')parser.add_argument('--sync-bn', type=bool, default=None,help='whether to use sync bn (default: auto)')parser.add_argument('--freeze-bn', type=bool, default=False,help='whether to freeze bn parameters (default: False)')args = parser.parse_args()args.cuda = not args.no_cuda and torch.cuda.is_available()if args.cuda:try:args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]except ValueError:raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')if args.sync_bn is None:if args.cuda and len(args.gpu_ids) > 1:args.sync_bn = Trueelse:args.sync_bn = Falsemodel_s_time = time.time()model = DeepLab(num_classes=args.num_classes,backbone=args.backbone,output_stride=args.out_stride,sync_bn=args.sync_bn,freeze_bn=args.freeze_bn)ckpt = torch.load(args.ckpt, map_location='cpu')model.load_state_dict(ckpt['state_dict'])model = model.cuda()model_u_time = time.time()model_load_time = model_u_time-model_s_timeprint("model load time is {}".format(model_load_time))composed_transforms = transforms.Compose([tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),tr.ToTensor()])for name in os.listdir(args.in_path):s_time = time.time()image = Image.open(args.in_path+"/"+name).convert('RGB')# image = Image.open(args.in_path).convert('RGB')target = Image.open(args.in_path+"/"+name).convert('L')sample = {'image': image, 'label': target}tensor_in = composed_transforms(sample)['image'].unsqueeze(0)model.eval()if args.cuda:tensor_in = tensor_in.cuda()with torch.no_grad():output = model(tensor_in)grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy()),3, normalize=False, range=(0, 255))save_image(grid_image,args.in_path+"/"+"{}_mask.png".format(name[0:-4]))u_time = time.time()img_time = u_time-s_timeprint("image:{} time: {} ".format(name,img_time))# save_image(grid_image, args.out_path)# print("type(grid) is: ", type(grid_image))# print("grid_image.shape is: ", grid_image.shape)print("image save in in_path.")
if __name__ == "__main__":main()# python demo.py --in-path your_file --out-path your_dst_file

测试命令

 python testdemo.py --ckpt run/Seg552/deeplab-mobilenet/checkpoint.pth.tar --backbone mobilenet

参考博客:
Pytorch 语义分割DeepLabV3+ 训练自己的数据集 从数据准备到模型训练
【windows10】使用pytorch版本deeplabv3+训练自己数据集
制作自己的语义分割数据集(仿voc2012数据集)

Pytorch版deeplabv3+环境配置训练自己的数据集相关推荐

  1. 【深度之眼PyTorch框架班第五期】作业打卡01:PyTorch简介及环境配置;PyTorch基础数据结构——张量

    文章目录 任务名称 任务简介 详细说明 作业 1. 安装anaconda,pycharm, CUDA+CuDNN(可选),虚拟环境,pytorch,并实现hello pytorch查看pytorch的 ...

  2. mysql免安装版net不是_MYSQL 免安装版的环境配置

    如:D:\Program Files\mysql-5.6.23-winx64 2.配置my.ini文件: [client] port=3306 default-character-set=utf8 [ ...

  3. 基于cuda10.0的pytorch深度学习环境配置

    基于cuda10.0的pytorch深度学习环境配置(报错解决) 1.首先查看自己nvidia 显卡的版本,一般都能适用cuda10.0: 1.打开win+s 搜索nvidia控制面版,查看系统信息 ...

  4. PyTorch安装与环境配置

    PyTorch安装与环境配置 NLP课程第一个实验,需要安装PyTorth,下面是一些安装方法及遇到的问题记录. 一.本地安装环境 Windows10 + VScode 网上大多数都是 anacond ...

  5. macbook m1版 前端环境配置

    macbook m1版 前端环境配置 安装 Homebrew 安装nvm 使用nvm安装node 安装git 安装 Homebrew 复制以下代码到你的终端 /bin/zsh -c "$(c ...

  6. Mac版JDK环境配置及Java多版本切换

    Mac版JDK环境配置及Java多版本切换 一.下载JDK包 JDK 各版本可通过 ORACLE 官网下载 ,下载较慢,可以百度搜索 Java Development Kit Mac 找下国内资源 传 ...

  7. OpenGL超级宝典(第五版)环境配置

    本文转自:http://blog.csdn.net/sunny_unix/article/details/8056807,感谢作者分享. OpenGL超级宝典(第五版)环境配置 Vs2008+winX ...

  8. OpenGL超级宝典(第五版) 环境配置

    特别提醒:有些在word中或者其他中的代码复制到vs中会报错,原因是word中有些隐含的字符,复制到vs中就会报错:重新输一遍就可以解决问题,这里只是提醒下! 可以参阅我前面转载的一篇文章,进行比较然 ...

  9. OpenGL超级宝典(第五版) 环境配置(WinXp+VS2008)

    转自:http://blog.csdn.net/sunny_unix/article/details/8056807 OpenGL超级宝典(第五版)环境配置 1.各种库的配置 (1)glew 下载:h ...

最新文章

  1. Tableau可视化分析实战系列Tableau基础概念全解析 (一)-数据结构及字段
  2. 深度干货 | 多维分析中的 UV 与 PV
  3. C#中二进制和流之间的各种相互转换
  4. MySQL中的读锁和写锁
  5. Java编程语言的历史和未来
  6. jsp中@import导入外部样式表与link链入外部样式表的区别
  7. DTCC 2020 | 阿里云梁高中:DAS之基于Workload的全局自动优化实践
  8. oracle同机单实例加入集群,将oracle同机单实例加入rac集群的操作步骤
  9. 去年写的测试GDAL用法的一些函数
  10. 谈谈对一些软件架构设计箴言的理解 对软件的过早地优化是万恶的根源 反设计模式案例简介...
  11. 深度学习如何入门?知乎
  12. 【单片机基础篇】51单片机流水灯
  13. 【美赛】全面助力2023年美国大学生数学建模竞赛,祝大家取得好成绩
  14. Oracle新增字段后,写入数据是提示ORA-00917:XXX 标识符无效
  15. win10设置计算机关机时间,w10怎么设置自动关机_win10电脑设置自动关机的方法
  16. 麦克纳姆轮(全向轮)
  17. MAC软件推荐(Java方向)
  18. Python一行代码实现正三角形
  19. 宇视NVR录像机下载录像没有声音如何解决
  20. 市场分析——行业背景分析

热门文章

  1. 【一头扎进JMS】(4)----RabbitMQ概述
  2. 详解MySQL之事务
  3. 1036: 某年某月有多少天 C语言
  4. 调节阀各种特性气动调节阀如何存放
  5. 需求收集方法工具,以及进行需求分析的6大要素
  6. 【金猿投融展】众盟科技——专注商业智能的技术服务平台
  7. 如何用python制作小游戏
  8. powershell获取linux文件,技术|微软爱上 Linux:当 PowerShell 来到 Linux 时
  9. 芜湖市小学生c语言培训班,芜湖少儿学编程-地址-电话
  10. 快速排序的四种python实现