Background

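At the time, state-of-the-art results were obtained by training residual networks with SGD with momentum as the optimizer. The main difficulties in training a DNN are the learning-rate schedule and the amount of L2 weight decay regularization. Usually the learning rate is either kept constant during training or divided by a constant factor at fixed intervals. SGDR was proposed to improve the learning-rate schedule: at every restart the learning rate is reset to an initial value and then decayed according to a schedule (cosine annealing in the paper). Compared with the learning-rate schedules in common use, it reaches comparable or better results with 2 to 4 times fewer epochs.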
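
For reference, the schedule from the SGDR paper (arXiv 1608.03983) can be written as below. T_cur counts epochs since the last restart, T_i is the length of the current cycle, and after each restart the cycle length is multiplied by T_mult. The example values (η_max = 0.1, η_min = 0, T_i = 10) are chosen only to illustrate the shape of one cycle: the rate starts at η_max, is halfway down at mid-cycle, and reaches η_min at the restart.

```latex
\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\tfrac{T_{cur}}{T_i}\pi\right)\right),
\qquad T_{i+1} = T_{mult}\, T_i

\text{Example: } \eta_{max} = 0.1,\ \eta_{min} = 0,\ T_i = 10:
\quad T_{cur} = 0 \Rightarrow \eta_t = 0.05\,(1 + \cos 0) = 0.1,
\quad T_{cur} = 5 \Rightarrow \eta_t = 0.05\,(1 + \cos\tfrac{\pi}{2}) = 0.05,
\quad T_{cur} = 10 \Rightarrow \eta_t = 0.05\,(1 + \cos\pi) = 0
```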

import math
import os

import torch
from torch.optim.lr_scheduler import _LRScheduler


class CosineAnnealingLR_with_Restart(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr,
    :math:`T_{cur}` is the number of epochs since the last restart and
    :math:`T_e` is the length of the current cycle in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_e}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. PyTorch's
    ``CosineAnnealingLR`` only implements the cosine annealing part of SGDR,
    so I added my own implementation of the restarts part. (Newer PyTorch
    versions also provide ``CosineAnnealingWarmRestarts``.)

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Length of the first cycle, in scheduler steps (epochs here).
        T_mult (float): Factor by which the cycle length grows after each restart.
        model (pytorch model): The model to save.
        out_dir (str): Directory to save snapshots.
        take_snapshot (bool): Whether to save snapshots at every restart.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, T_mult, model, out_dir, take_snapshot,
                 eta_min=0, last_epoch=-1):
        self.T_max = T_max               # epoch at which the current cycle ends
        self.T_mult = T_mult
        self.Te = self.T_max             # length of the current cycle
        self.eta_min = eta_min
        self.current_epoch = last_epoch  # T_cur: epochs since the last restart
        self.model = model
        self.out_dir = out_dir
        self.take_snapshot = take_snapshot
        self.lr_history = []
        super(CosineAnnealingLR_with_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        new_lrs = [self.eta_min + (base_lr - self.eta_min) *
                   (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2
                   for base_lr in self.base_lrs]
        self.lr_history.append(new_lrs)
        return new_lrs

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # keep get_last_lr() working, since this step() replaces the base class's
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

        # restart
        if self.current_epoch == self.Te:
            print("restart at epoch {:03d}".format(self.last_epoch + 1))
            if self.take_snapshot:
                torch.save({'epoch': self.T_max,
                            'state_dict': self.model.state_dict()},
                           os.path.join(self.out_dir, 'Weight',
                                        'snapshot_e_{:03d}.pth.tar'.format(self.T_max)))
            # reset the number of epochs since the last restart
            self.current_epoch = 0
            # length of the next cycle and the epoch at which it ends
            self.Te = int(self.Te * self.T_mult)
            self.T_max = self.T_max + self.Te
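
To see which epochs trigger a restart and which value `self.T_max` holds at that moment (it is the number used in the snapshot file name), the counter bookkeeping in `step()` can be replayed without PyTorch. `simulate_sgdr` below is a hypothetical helper, not part of the original code; it assumes a single parameter group and mirrors the constructor's initial `step()` followed by one `step()` per epoch.

```python
import math


def simulate_sgdr(base_lr, T_max, T_mult, n_steps, eta_min=0.0):
    """Replay CosineAnnealingLR_with_Restart's counters for one param group.

    Returns (lrs, restarts): lrs[k] is the lr set by the (k+1)-th step() call
    after construction; restarts holds (last_epoch, T_max) at each restart.
    """
    Te = T_max          # length of the current cycle
    current_epoch = 0   # the constructor's initial step() leaves T_cur at 0
    last_epoch = 0
    lrs, restarts = [], []
    for _ in range(n_steps):
        last_epoch += 1
        current_epoch += 1
        lr = eta_min + (base_lr - eta_min) * (
            1 + math.cos(math.pi * current_epoch / Te)) / 2
        lrs.append(lr)
        if current_epoch == Te:
            # same order as step(): record, reset T_cur, grow the cycle
            restarts.append((last_epoch, T_max))
            current_epoch = 0
            Te = int(Te * T_mult)
            T_max = T_max + Te
    return lrs, restarts


lrs, restarts = simulate_sgdr(base_lr=0.1, T_max=3, T_mult=2, n_steps=21)
print(restarts)  # [(3, 3), (9, 9), (21, 21)]
```

With `T_mult=2` the cycles last 3, 6 and 12 epochs, and `T_max` always equals the epoch at which the current cycle ends, so each snapshot is named after its restart epoch. The lr reaches `eta_min` exactly on the restart step and the next cycle resumes at `T_cur = 1`.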
# Excerpt from the training function. The following are assumed to be
# defined earlier in the script (not shown here): np, torch, timer,
# time_to_str, start, config, net, optimizer, criterion, metric,
# train_loader, valid_loader, do_valid_test, log, out_dir, i, start_iter,
# iter_smooth and batch_loss (a float32 array of length 6).
sgdr = CosineAnnealingLR_with_Restart(optimizer,
                                      T_max=config.cycle_inter,
                                      T_mult=1,
                                      model=net,
                                      out_dir='../input/',
                                      take_snapshot=False,
                                      eta_min=1e-3)

global_min_acer = 1.0
# parser.add_argument('--cycle_num', type=int, default=10)
# parser.add_argument('--cycle_inter', type=int, default=50)
for cycle_index in range(config.cycle_num):  # cycle_num cycles in total
    print('cycle index: ' + str(cycle_index))
    min_acer = 1.0

    for epoch in range(0, config.cycle_inter):  # each cycle has cycle_inter epochs
        # update the learning rate (one scheduler step per epoch)
        sgdr.step()
        lr = optimizer.param_groups[0]['lr']
        print('lr : {:.4f}'.format(lr))

        sum_train_loss = np.zeros(6, np.float32)
        n_batches = 0

        # switch to training mode once per epoch and clear the gradients
        net.train()
        optimizer.zero_grad()

        for inputs, truth in train_loader:
            it = i + start_iter

            # one iteration update -------------
            # inputs = inputs.cuda()
            # truth = truth.cuda()

            # forward pass (net(...) rather than net.forward(...) so hooks run)
            logit, _, _ = net(inputs)
            truth = truth.view(logit.shape[0])

            # compute the loss and the metric
            loss = criterion(logit, truth)
            precision, _ = metric(logit, truth)

            # backward pass, weight update, then clear the gradients
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # print statistics ------------
            batch_loss[:2] = np.array((loss.item(), precision.item()))
            sum_train_loss += batch_loss
            n_batches += 1
            if it % iter_smooth == 0:
                # smoothed training loss over the last n_batches batches
                train_loss = sum_train_loss / n_batches
                sum_train_loss = np.zeros(6, np.float32)
                n_batches = 0

            i = i + 1

        # validate only in the second half of each cycle
        if epoch >= config.cycle_inter // 2:
            # switch to eval mode, compute the validation loss, switch back
            net.eval()
            valid_loss, _ = do_valid_test(net, valid_loader, criterion)
            net.train()

            # keep the best weights of this cycle and of the whole run
            if valid_loss[1] < min_acer and epoch > 0:
                min_acer = valid_loss[1]
                ckpt_name = out_dir + '/checkpoint/Cycle_' + str(cycle_index) + '_min_acer_model.pth'
                torch.save(net.state_dict(), ckpt_name)
                log.write('save cycle ' + str(cycle_index) + ' min acer model: ' + str(min_acer) + '\n')
            if valid_loss[1] < global_min_acer and epoch > 0:
                global_min_acer = valid_loss[1]
                ckpt_name = out_dir + '/checkpoint/global_min_acer_model.pth'
                torch.save(net.state_dict(), ckpt_name)
                log.write('save global min acer model: ' + str(global_min_acer) + '\n')

            asterisk = ' '
            log.write(config.model_name + ' Cycle %d: %0.4f %5.1f %6.1f | %0.6f  %0.6f  %0.3f %s  | %0.6f  %0.6f |%s \n' % (
                cycle_index, lr, it, epoch,
                valid_loss[0], valid_loss[1], valid_loss[2], asterisk,
                batch_loss[0], batch_loss[1],
                time_to_str((timer() - start), 'min')))

    # save the weights from the last epoch of this cycle (I don't think this is necessary)
    ckpt_name = out_dir + '/checkpoint/Cycle_' + str(cycle_index) + '_final_model.pth'
    torch.save(net.state_dict(), ckpt_name)
    log.write('save cycle ' + str(cycle_index) + ' final model \n')
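
In this loop the outer `for cycle_index` loop fixes the cycle boundaries on its own, while the scheduler decides independently when to restart, so the per-cycle "min ACER" checkpoints only match real SGDR cycles if the two agree. The helpers below are hypothetical (not from the original code) and assume `sgdr.step()` is called exactly once at the start of every epoch, as in the loop. With `T_mult=1` and `T_max=config.cycle_inter` the boundaries coincide; with `T_mult=2`, for example, they diverge after the first cycle.

```python
def scheduler_restart_epochs(T_max, T_mult, total_epochs):
    """Scheduler steps (1-based) at which CosineAnnealingLR_with_Restart restarts.

    Mirrors the Te / T_max update in step(): the first cycle ends at step
    T_max, and each later cycle is int(previous_length * T_mult) steps long.
    """
    restarts = []
    cycle_len = T_max
    end = T_max
    # stop if the cycle length collapses to 0 (T_mult < 1) to avoid looping forever
    while cycle_len > 0 and end <= total_epochs:
        restarts.append(end)
        cycle_len = int(cycle_len * T_mult)
        end += cycle_len
    return restarts


def loop_cycle_ends(cycle_num, cycle_inter):
    """Steps at which the outer training loop finishes each of its cycles."""
    return [cycle_inter * (k + 1) for k in range(cycle_num)]


def cycles_aligned(cycle_num, cycle_inter, T_mult):
    """True if the scheduler restarts exactly where the loop's cycles end."""
    total = cycle_num * cycle_inter
    return scheduler_restart_epochs(cycle_inter, T_mult, total) == \
        loop_cycle_ends(cycle_num, cycle_inter)


print(cycles_aligned(cycle_num=10, cycle_inter=50, T_mult=1))  # True
print(cycles_aligned(cycle_num=10, cycle_inter=50, T_mult=2))  # False
```

If a growing cycle length is wanted, the outer loop would have to iterate over the scheduler's cycle lengths instead of a fixed `cycle_inter`.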
