Trainer class initialization ——

class Trainer(object):
    def __init__(self, args):
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = True if (APEX_AVAILABLE and args.use_amp) else False
        self.opt_level = args.opt_level

        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        # Build the data loaders. In search mode the training set is split into two parts,
        # A and B: A is used to train the network weights, B to train the architecture parameters.
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Optionally use precomputed class-balancing weights, i.e. per-class weights that
        # compensate for class imbalance in the segmentation loss.
        if args.use_balanced_weights:
            # Path of the precomputed weights file
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset),
                                                args.dataset + '_classes_weights.npy')
            # Load the weights
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                raise NotImplementedError
                # if so, which trainloader to use?
                # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            # Convert the numpy weights into a torch tensor
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        # Define the criterion; the default args.loss_type is 'ce' (cross-entropy).
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define the AutoDeeplab network (the 12-layer search supernet).
        model = AutoDeeplab(self.nclass, 12, self.criterion, self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)

        # The network weights are optimized with SGD ...
        optimizer = torch.optim.SGD(
            model.weight_parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
        self.model, self.optimizer = model, optimizer

        # ... while the architecture parameters (arch_parameters) are optimized with Adam.
        self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                                    lr=args.arch_lr, betas=(0.9, 0.999),
                                                    weight_decay=args.arch_weight_decay)

        # Define the Evaluator, whose methods include the mIoU computation.
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
        # TODO: Figure out if len(self.train_loader) should be divided by two? In other module as well

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Mixed precision (NVIDIA apex / amp)
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
                self.model, [self.optimizer, self.architect_optimizer],
                opt_level=self.opt_level,
                keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        # checkpoint = torch.load(args.resume)
        # print('about to load state_dict')
        # self.model.load_state_dict(checkpoint['state_dict'])
        # print('model loaded')
        # sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in module object we have to clean it
            if args.clean_module:
                self.model.load_state_dict(checkpoint['state_dict'])
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of dataparallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)
            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(), checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
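The commented-out calculate_weigths_labels call hints at how the *_classes_weights.npy file is produced. Below is a minimal sketch of deriving class-balancing weights from pixel frequencies; the helper name compute_class_weights and the inverse-log weighting are illustrative assumptions, not the repository's exact implementation.

import numpy as np

def compute_class_weights(label_arrays, num_classes):
    """Hypothetical helper: derive per-class weights from pixel frequencies.

    label_arrays: iterable of integer label maps (H x W numpy arrays).
    Returns a float32 array of shape (num_classes,) that up-weights rare classes.
    """
    counts = np.zeros(num_classes, dtype=np.float64)
    for label in label_arrays:
        # Count how many pixels belong to each class in this label map.
        hist = np.bincount(label.flatten(), minlength=num_classes)
        counts += hist[:num_classes]          # pixels with an ignore index (e.g. 255) fall outside the slice
    freqs = counts / counts.sum()
    # Inverse-log weighting (ENet-style balancing): rare classes get larger weights.
    weights = 1.0 / np.log(1.02 + freqs)
    return weights.astype(np.float32)

# The result would then be saved so the Trainer can load it, e.g.:
# np.save('cityscapes_classes_weights.npy', compute_class_weights(all_labels, nclass))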

training(epoch) —— the training loop

    def training(self, epoch):
        # Accumulated training loss for this epoch.
        train_loss = 0.0
        # Put the AutoDeeplab supernet into training mode.
        self.model.train()
        # Progress bar over training split A.
        tbar = tqdm(self.train_loaderA)
        num_img_tr = len(self.train_loaderA)
        for i, sample in enumerate(tbar):
            # Fetch the image and label of the current batch.
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # Adjust the learning rate for this iteration according to the chosen
            # args.lr_scheduler policy (best_pred is passed along for bookkeeping).
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            # Zero out the gradients of the weight optimizer.
            self.optimizer.zero_grad()
            # Forward the image through AutoDeeplab to get the output.
            output = self.model(image)
            # criterion defaults to cross-entropy ('ce'), so loss is the cross-entropy
            # between output and target.
            loss = self.criterion(output, target)
            # Backpropagate and update the network weights.
            if self.use_amp:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            # Take one optimization step on the weights.
            self.optimizer.step()

            # Once the epoch reaches the alpha_epoch threshold, also update the
            # architecture parameters.
            if epoch >= self.args.alpha_epoch:
                search = next(iter(self.train_loaderB))
                # Fetch one batch of images and labels from training split B.
                image_search, target_search = search['image'], search['label']
                if self.args.cuda:
                    image_search, target_search = image_search.cuda(), target_search.cuda()
                # Zero out the gradients of the architecture optimizer.
                self.architect_optimizer.zero_grad()
                # Forward the split-B batch through the current model.
                output_search = self.model(image_search)
                # Compute the architecture loss.
                arch_loss = self.criterion(output_search, target_search)
                # Backpropagate and update the architecture parameters alpha and beta.
                if self.use_amp:
                    with amp.scale_loss(arch_loss, self.architect_optimizer) as arch_scaled_loss:
                        arch_scaled_loss.backward()
                else:
                    arch_loss.backward()
                # Take one optimization step on the architecture parameters.
                self.architect_optimizer.step()

            # Accumulate the training loss and show it on the progress bar.
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            # Show 10 * 3 inference results each epoch
            if i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        # When no_val is set, validation is skipped, so a checkpoint is saved at the end
        # of every epoch instead of only when a new best mIoU is reached.
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
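Stripped of logging and AMP, the alternating update above follows a first-order DARTS-style recipe: the network weights are trained on split A with SGD, and (after alpha_epoch) the architecture parameters are trained on split B with Adam. The following is a minimal self-contained sketch of that pattern with a toy model in place of AutoDeeplab; ToySearchModel and its single alpha parameter are illustrative assumptions, not the repository's classes.

import torch
import torch.nn as nn

# Toy stand-in for AutoDeeplab: ordinary 'weight' parameters plus a learnable architecture parameter.
class ToySearchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, 3, padding=1)   # network weights
        self.alpha = nn.Parameter(torch.zeros(2))   # architecture parameters

    def forward(self, x):
        # Softmax over alpha mixes two candidate paths (a conv path and a cheap channel-mean path).
        w = torch.softmax(self.alpha, dim=0)
        return w[0] * self.conv(x) + w[1] * x.mean(dim=1, keepdim=True)

model = ToySearchModel()
criterion = nn.MSELoss()
w_optimizer = torch.optim.SGD([model.conv.weight, model.conv.bias], lr=0.01, momentum=0.9)
a_optimizer = torch.optim.Adam([model.alpha], lr=3e-3, betas=(0.9, 0.999))

for step in range(10):
    # Step 1: update the network weights on split A.
    xA, yA = torch.randn(2, 3, 8, 8), torch.randn(2, 4, 8, 8)
    w_optimizer.zero_grad()
    criterion(model(xA), yA).backward()
    w_optimizer.step()

    # Step 2: update the architecture parameters on split B
    # (in the real code this only starts once epoch >= alpha_epoch).
    xB, yB = torch.randn(2, 3, 8, 8), torch.randn(2, 4, 8, 8)
    a_optimizer.zero_grad()
    criterion(model(xB), yB).backward()
    a_optimizer.step()

Because each optimizer only holds its own parameter group, calling zero_grad() on it before the corresponding backward pass keeps the weight step and the architecture step from contaminating each other's gradients, which is exactly how training() uses self.optimizer and self.architect_optimizer.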

validation(epoch) —— the validation loop

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            # Fetch the image and label of the current validation batch.
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # The validation set is not used for training, so run the forward pass
            # under torch.no_grad().
            with torch.no_grad():
                output = self.model(image)
            # Compute the validation loss.
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            # Turn the network output into per-pixel class predictions.
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training:
        # compute the Acc, Acc_class, mIoU and FWIoU metrics.
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        # Log the metrics to TensorBoard and print them.
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            # New best mIoU: update best_pred and save this checkpoint as the best model.
            is_best = True
            self.best_pred = new_pred
            if torch.cuda.device_count() > 1:
                state_dict = self.model.module.state_dict()
            else:
                state_dict = self.model.state_dict()
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': state_dict,
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
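The Evaluator accumulates (target, pred) pairs into a confusion matrix and derives the metrics from it. A minimal sketch of how mIoU is typically computed from such a matrix follows; the function names mean_iou and add_batch are illustrative and the details are an assumption about the Evaluator's internals, not a copy of it.

import numpy as np

def add_batch(conf_matrix, target, pred, nclass):
    """Accumulate one batch of (target, pred) label maps into the confusion matrix."""
    mask = (target >= 0) & (target < nclass)              # drop ignore-index pixels
    idx = nclass * target[mask].astype(int) + pred[mask]
    conf_matrix += np.bincount(idx, minlength=nclass ** 2).reshape(nclass, nclass)
    return conf_matrix

def mean_iou(conf_matrix):
    """Compute mean Intersection-over-Union from an (nclass x nclass) confusion matrix.

    Rows index the ground-truth class, columns the predicted class.
    """
    intersection = np.diag(conf_matrix)                   # true positives per class
    union = conf_matrix.sum(axis=1) + conf_matrix.sum(axis=0) - intersection
    iou = intersection / np.maximum(union, 1)             # avoid division by zero
    return np.nanmean(iou)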

main() —— the main function

def main():
    args = obtain_search_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epoches = {'coco': 30, 'cityscapes': 40, 'pascal': 50, 'kd': 10}
        args.epochs = epoches[args.dataset.lower()]

    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    # args.lr = args.lr / (4 * len(args.gpu_ids)) * args.batch_size

    if args.checkname is None:
        args.checkname = 'deeplab-' + str(args.backbone)
    print(args)
    torch.manual_seed(args.seed)
    # Build the Trainer, which sets up the whole search (model, optimizers, loaders, ...).
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        # Run one training epoch.
        trainer.training(epoch)
        if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
            # Validate every eval_interval epochs (i.e. at epochs eval_interval-1, 2*eval_interval-1, ...).
            trainer.validation(epoch)

    trainer.writer.close()
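As a worked example of the default-filling logic above, here is what the values come out to for an assumed run with dataset 'cityscapes', gpu_ids '0,1', and no explicit epochs or batch sizes; the numbers follow directly from the code in main().

# Assumed inputs: args.dataset = 'cityscapes', args.gpu_ids = '0,1',
# args.epochs = args.batch_size = args.test_batch_size = args.sync_bn = None.
gpu_ids = [int(s) for s in '0,1'.split(',')]            # -> [0, 1]
epoches = {'coco': 30, 'cityscapes': 40, 'pascal': 50, 'kd': 10}
epochs = epoches['cityscapes']                          # -> 40
batch_size = 4 * len(gpu_ids)                           # -> 8 (4 images per GPU)
test_batch_size = batch_size                            # -> 8
sync_bn = len(gpu_ids) > 1                              # -> True (multi-GPU enables SyncBN)
print(epochs, batch_size, test_batch_size, sync_bn)     # 40 8 8 True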
