首先祝贺百度团队百度斩获NeurIPS2020挑战赛冠军,https://www.jiqizhixin.com/articles/2020-12-09-2。
在此次比赛中使用的是基于飞桨深度学习框架开发的图像分割套件PaddleSeg。从这篇文章开始,我将持续更新《图像分割套件PaddleSeg全面解析》系列文章,由于个人水平有限,如有错误之处请见谅,谢谢。

PaddleSeg是百度基于自家的PaddlePaddle开发的端到端图像分割开发套件。包含多种主流的分割网络。PaddleSeg采用模块化的方式设计,可以通过配置文件方式进行模型组合,帮助开发者在不需要深入了解图像分割原理的情况,实现方便快捷的完成模型的训练与部署。 但是在对需要对模型进行修改优化的时候,还是需要对图像分割原理以及PaddleSeg套件有进一步了解,本文的主要内容就是对PaddleSeg进行代码解读,帮助开发者进一步了解图像分割原理以及PaddleSeg的实现方法。本文只要介绍PaddleSeg的动态图的实现方法。

本代码解读基于PaddleSeg动态图版本V2.0.0-rc。 PaddleSeg套件的源代码可以从GitHub上进行下载,命令如下:
PaddleSeg套件的源代码可以从GitHub上进行下载,命令如下:

!git clone https://github.com/PaddlePaddle/PaddleSeg.git

PaddleSeg目录包含下几个目录:

  • configs:保存不同神经网络的配置文件。
  • contrib:真实案例相关配置与数据
  • legacy:静态图版本代码,只维护,不更新新功能
  • docs:文档
  • paddleseg:PaddleSeg核心代码,包含训练、评估、推理等文件。
  • tools:工具脚本
  • train.py:训练入口文件
  • val.py:评估模型文件
  • predict.py:预测文件

本文大概分为以下7个部分:
1.train.py代码解读:这里主要讲解paddleseg训练入口文件的代码。该文件里描述了参数的解析,训练的启动方法,以及为训练准备的资源等。
图像分割套件PaddleSeg全面解析(一)train.py代码解读
2.Config代码解读:这里主要讲解了Config类的代码,config类由train.py实例化,通过运行train.py时指定的配置文件生成config对象。
图像分割套件PaddleSeg全面解析(二)
3.DataSet代码解读:这里主要讲解了Dataset类,对每一种数据集都抽象为一个类,通过继承Dataset类,实现匿名协议,构建文件列表,供训练使用。
图像分割套件PaddleSeg全面解析(三)
4.数据增强代码解读:这里主要讲解了数据预处理与增强的一些常用算法。
图像分割套件PaddleSeg全面解析(四)
5.模型与Backbone代码解读:这里主要讲解常用的模型以及backbone的网络与算法。
图像分割套件PaddleSeg全面解析(五)
图像分割套件PaddleSeg全面解析(六)

6.损失函数代码解读:这里主要讲解常用的损失函数的代码与算法。
图像分割套件PaddleSeg全面解析(七)

7.评估模型代码解读:这里讲解评估模型性能的代码与评估方法。

1.train.py代码解读

神经网络模型训练需要使用train.py来完成。是PaddleSeg中核心代码。

我们先结合下图,来了解一下训练之前的准备工作。

可以通过以下命令快速开始一个训练任务。

python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml

命令中的–config参数指定本次训练的配置文件,配置文件的详细介绍可以参见后面的第二节。

在执行train.py脚本的最开始会导入一些包,如下:

from paddleseg.cvlibs import manager, Config
from paddleseg.utils import get_sys_env, logger
from paddleseg.core import train
  • 在导入manager模块时会创建图中左侧manage方框中的5个ComponentManager对象,他们分别是MODELS、BACKBONES、DATASETS、TRANSFORMS和LOSSES。这5个ComponentManager类似字典,用来维护套件中所有对应的类,比如FCN类、ResNet类等,通过类的名称就可以找到对应的类。
  • 在train.py运行时,会创建config对象。
cfg = Config(args.cfg,learning_rate=args.learning_rate,iters=args.iters,batch_size=args.batch_size)

在创建config对象时,会通过manager获取到配置文件中指定的类,并实例化对象,比如model和loss等。

  • train.py调用train函数,将config作为实参传入。train函数获取config中的成员来完成训练工作。

下面我们来详细解读一下train.py,首先我们从train.py的入口代码开始:

if __name__ == '__main__':# 处理运行train.py传入的参数args = parse_args()#调用主函数。main(args)

首先看一下第一行代码

args = parse_args()

parse_args()的实现如下:

 #配置文件路径parser.add_argument("--config", dest="cfg", help="The config file.", default=None, type=str)#总训练迭代次数parser.add_argument('--iters',dest='iters',help='iters for training',type=int,default=None)#batchsize大小parser.add_argument('--batch_size',dest='batch_size',help='Mini batch size of one gpu or cpu',type=int,default=None)#学习率parser.add_argument('--learning_rate',dest='learning_rate',help='Learning rate',type=float,default=None)#保存模型间隔parser.add_argument('--save_interval',dest='save_interval',help='How many iters to save a model snapshot once during training.',type=int,default=1000)#如果需要恢复训练,指定恢复训练模型路径parser.add_argument('--resume_model',dest='resume_model',help='The path of resume model',type=str,default=None)#模型保存路径parser.add_argument('--save_dir',dest='save_dir',help='The directory for saving the model snapshot',type=str,default='./output')#数据读取器线程数量,目前在AI Studio建议设置为0.parser.add_argument('--num_workers',dest='num_workers',help='Num workers for data loader',type=int,default=0)#在训练过程中进行模型评估parser.add_argument('--do_eval',dest='do_eval',help='Eval while training',action='store_true')#日志打印间隔parser.add_argument('--log_iters',dest='log_iters',help='Display logging information at every log_iters',default=10,type=int)#开启可视化训练parser.add_argument('--use_vdl',dest='use_vdl',help='Whether to record the data to VisualDL during training',action='store_true')

然后看下一行代码:

main(args)

main 的代码如下:

def main(args):#获取环境信息,比如操作系统类型、python版本号、Paddle版本、GPU数量、Opencv版本、gcc版本等内容env_info = get_environ_info()#打印环境信息info = ['{}: {}'.format(k, v) for k, v in env_info.items()]info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info +['-' * 48])logger.info(info)#确定是否使用GPUplace = 'gpu' if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] else 'cpu'#设置使用GPU或者CPUpaddle.set_device(place)#如果没有指定配置文件这抛出异常。if not args.cfg:raise RuntimeError('No configuration file specified.')#构建cfg对象,该对象包含数据集、图像增强、模型结构、损失函数等设置#该对象基于命令行传入参数以及yaml配置文件构建cfg = Config(args.cfg,learning_rate=args.learning_rate,iters=args.iters,batch_size=args.batch_size)#从Config对象中获取train_data对象。train_data为迭代器train_dataset = cfg.train_dataset#如果没有设置训练集,抛出异常if not train_dataset:raise RuntimeError('The training dataset is not specified in the configuration file.')#如果需要在训练中进行模型评估,则需要获取到验证集val_dataset = cfg.val_dataset if args.do_eval else None#获取损失函数losses = cfg.lossmsg = '\n---------------Config Information---------------\n'msg += str(cfg)msg += '------------------------------------------------'#打印出详细设置。logger.info(msg)#调用core/train.py中train函数进行训练train(cfg.model,train_dataset,val_dataset=val_dataset,optimizer=cfg.optimizer,save_dir=args.save_dir,iters=cfg.iters,batch_size=cfg.batch_size,resume_model=args.resume_model,save_interval=args.save_interval,log_iters=args.log_iters,num_workers=args.num_workers,use_vdl=args.use_vdl,losses=losses)

在train.py脚本中,除了调用config对配置文件进行解析,就是调用core/train.py中的train函数完成训练工作。下面我先看一下train函数的工作流程。

从图中看出,整个训练过程由两个循环组成,最外层循环由总迭代次数控制,需要在yaml文件中配置,如下代码:

iters: 80000

内层循环由数据读取器控制,循环会遍历数据读取器中所有的数据,直至全部读取完毕跳出循环,这个过程通常也被叫做一个epoch。

下面我们详细解析一下core/train.py中train函数的代码。

首先看一下train函数的代码概要。

然后我们再看一下详细的代码解读,

def train(model, #模型对象train_dataset, #训练集对象val_dataset=None, #验证集对象,如果训练过程不需要验证,可以为Noneoptimizer=None, #优化器对象save_dir='output', #模型输出路径iters=10000, #训练最大迭代次数batch_size=2, #batch size大学resume_model=None, # 是否需要恢复训练,如果需要指定恢复训练模型权重路径save_interval=1000, # 模型保存间隔log_iters=10, # 设置日志输出间隔num_workers=0, #设置数据读取器线程数,0为不开启多进程use_vdl=False, #是否使用vdllosses=None): # 损失函数系数,当使用多个损失函数时,需要指定各个损失函数的系数。#为了兼容多卡训练,这里需要获取显卡数量。nranks = paddle.distributed.ParallelEnv().nranks#在分布式训练中,每个显卡都会执行本程序,所以需要在程序里获取本显卡的序列号。local_rank = paddle.distributed.ParallelEnv().local_rank#循环起始的迭代数。如果是恢复训练的话,从恢复训练中获得起始的迭代数。#比如,在2000次迭代的时候保存了中间训练过程,通过resume恢复训练,那么start_iter则为2000。start_iter = 0if resume_model is not None:start_iter = resume(model, optimizer, resume_model)#创建保存输出模型文件的目录。if not os.path.isdir(save_dir):if os.path.exists(save_dir):os.remove(save_dir)os.makedirs(save_dir)#如果是多卡训练,则需要初始化多卡训练环境。if nranks > 1:# Initialize parallel training environment.paddle.distributed.init_parallel_env()strategy = paddle.distributed.prepare_context()ddp_model = paddle.DataParallel(model, strategy)#创建一个批量采样器,这里指定数据集,通过批量采样器组成一个batch。这里需要指定batch size,是否随机打乱,是否丢弃末尾不能组成一个batch的数据等参数。batch_sampler = paddle.io.DistributedBatchSampler(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)#通过数据集参数和批量采样器等参数构建一个数据读取器。可以通过num_works设置多进程,这里的多进程通过共享内存通信,#如果共享内存过小可能会报错,如果报错可以尝将num_workers设置为0,则不开启多进程。loader = paddle.io.DataLoader(train_dataset,batch_sampler=batch_sampler,num_workers=num_workers,return_list=True,)if use_vdl:from visualdl import LogWriterlog_writer = LogWriter(save_dir)#开启定时器timer = Timer()avg_loss = 0.0iters_per_epoch = len(batch_sampler)best_mean_iou = -1.0best_model_iter = -1train_reader_cost = 0.0train_batch_cost = 0.0timer.start()iter = start_iter#开始循环,通过迭代次数控制最外层循环。while iter < iters:#内部循环,遍历数据迭代器中的数据。for data in loader:iter += 1if iter > iters:break#记录读取器时间train_reader_cost += timer.elapsed_time()#保存样本images = data[0]#保存样本标签labels = data[1].astype('int64')#供BCELoss使用edges = Noneif len(data) == 3:edges = data[2].astype('int64')#如果有多张显卡,则开启分布式训练,如果只有一张显卡则直接调用模型对象进行训练。if nranks > 1:#通过模型前向运算获得预测结果logits_list = ddp_model(images)else:#通过模型前向运算获得预测结果logits_list = model(images)#通过标签计算损失loss = loss_computation(logits_list=logits_list,labels=labels,losses=losses,edges=edges)#计算模型参数的梯度loss.backward()#执行一次优化器并进行参数更新optimizer.step()#获取当前优化器的学习率。lr = optimizer.get_lr()if isinstance(optimizer._learning_rate,paddle.optimizer.lr.LRScheduler):optimizer._learning_rate.step()#清除模型中的梯度model.clear_gradients()#计算平均损失值avg_loss += loss.numpy()[0]train_batch_cost += timer.elapsed_time()#根据配置中的log_iters打印训练日志if (iter) % log_iters == 0 and local_rank == 0:avg_loss /= log_itersavg_train_reader_cost = train_reader_cost / log_itersavg_train_batch_cost = train_batch_cost / log_iterstrain_reader_cost = 0.0train_batch_cost = 0.0remain_iters = iters - itereta = calculate_eta(remain_iters, avg_train_batch_cost)logger.info("[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}".format((iter - 1) // iters_per_epoch + 1, iter, iters,avg_loss, lr, avg_train_batch_cost,avg_train_reader_cost, eta))if use_vdl:log_writer.add_scalar('Train/loss', avg_loss, iter)log_writer.add_scalar('Train/lr', lr, iter)log_writer.add_scalar('Train/batch_cost',avg_train_batch_cost, iter)log_writer.add_scalar('Train/reader_cost',avg_train_reader_cost, iter)avg_loss = 0.0#根据配置中的save_interval判断是否需要对当前模型进行评估。if (iter % save_interval == 0or iter == iters) and (val_dataset is not None):num_workers = 1 if num_workers > 0 else 0mean_iou, acc = evaluate(model, val_dataset, num_workers=num_workers)#评估后需要将模型训练模式,该模式影响dropout和batchnorm层model.train()#根据配置中的save_interval判断是否需要保存当前模型。if (iter % save_interval == 0 or iter == iters) and local_rank == 0:current_save_dir = os.path.join(save_dir,"iter_{}".format(iter))#如果输出路径不存在,需要创建目录。if not os.path.isdir(current_save_dir):os.makedirs(current_save_dir)#保存模型权重paddle.save(model.state_dict(),os.path.join(current_save_dir, 'model.pdparams'))#保存优化器权重,恢复训练会用到。paddle.save(optimizer.state_dict(),os.path.join(current_save_dir, 'model.pdopt'))#保存最佳模型。if val_dataset is not None:if mean_iou > best_mean_iou:best_mean_iou = mean_ioubest_model_iter = iterbest_model_dir = os.path.join(save_dir, "best_model")paddle.save(model.state_dict(),os.path.join(best_model_dir, 'model.pdparams'))logger.info('[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'.format(best_mean_iou, best_model_iter))if use_vdl:log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)log_writer.add_scalar('Evaluate/Acc', acc, iter)#重置定时器timer.restart()# Sleep for half a second to let dataloader release resources.time.sleep(0.5)if use_vdl:log_writer.close()

PaddleSeg套件训练入口train.py文件解读到此结束。

PaddleSeg仓库地址:https://github.com/PaddlePaddle/PaddleSeg

图像分割套件PaddleSeg全面解析(一)train.py代码解读相关推荐

  1. yolov5——train.py代码【注释、详解、使用教程】

    yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] 前言 1. p ...

  2. 【Image captioning】Show, Attend, and Tell 从零到掌握之三--train.py代码详解

    [Image captioning]Show, Attend, and Tell 从零到掌握之三–train.py代码详解 作者:安静到无声 个人主页 作者简介:人工智能和硬件设计博士生.CSDN与阿 ...

  3. im2rec.py代码解读

    im2rec.py解读 直接给代码,注释我写上了. #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Licensed to the Apache So ...

  4. YOLOV5训练代码train.py注释与解析

    YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...

  5. 精度87%!业内首个动静统一的图像分割套件重磅推出

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 方向:图像分割技术 图像分割技术是计算机视觉领域的一个重要的研究方向,早 ...

  6. pytorch YoLOV3 源码解析 train.py

    train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...

  7. 【项目实战】WaveNet 代码解析 —— train.py 【更新中】

    WaveNet 代码解析 -- train.py 文章目录 WaveNet 代码解析 -- train.py 简介 代码解析 全局变量解析 函数解析 main() get_arguments() va ...

  8. YOLOV5dataset.py代码注释与解析

    YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全 github: https://github.com/Laughing-q/yolov5_annotations Y ...

  9. 【YOLOV5-5.x 源码解读】train.py

    目录 前言 0.导入需要的包和基本配置 1.设置opt参数 2.main函数 2.1.logging和wandb初始化 2.2.判断是否使用断点续训resume, 读取参数 2.3.DDP mode设 ...

最新文章

  1. 谈 三层结构与MVC模式的区别
  2. 涨见识了,在终端执行 Python 代码的 6 种方式
  3. java的connect和http_java发起HttpURLConnection和HttpsURLConnection请求 | 学步园
  4. 下一代数据网加速器成立,致力于建设智能时代的大数据基础设施
  5. python 轨迹识别
  6. java和python哪个好就业2020-java和python哪个的前途更好?
  7. 北京迎“豪宅元年”:四环房价将迈入8万元时代
  8. fastdfs java上传文件_FastDFS java客户端文件上传demo
  9. css知识笔记(二)——盒子模型
  10. Nginx----基础
  11. tensorFlow13卷积神经网络发展
  12. 信息学奥赛C++语言:调整试题顺序
  13. C#随机不重复给数组赋值1-100并排序
  14. python读取txt文件出现UnicodeError
  15. 如何启用Domino 8 的ODS磁盘结构
  16. 12864 C语言程序 带详细注解
  17. CBoard项目学习
  18. apktool java_apktool的使用
  19. 电动汽车相关功率计算
  20. Jasperreport_6.18的吐血记录三之简易交叉表 + 页面预览和导出

热门文章

  1. 【OpenCV(C++)】分离颜色通道、多通道图像混合
  2. 妙莲千里寻师拜访记【转】
  3. 基于ASP.NET的电子商务网站管理系统_WEB管理系统_SQLServer数据库应用
  4. 转:张小龙:如何把产品做简单
  5. java菜鸟到大神要知道哪些知识?
  6. CentOS8搭建SonarQube9+SonarScanner+Postgresql+bitbucket+cppcheck 扫描C语言。(未完待续)
  7. 学习笔记(2022-5-26)——bind-dns
  8. 爱,请在来得及的时候说出口
  9. Prophet模型的简介以及案例分析
  10. Textured Neural Avatars 论文方法简述