昨天说完了args的作用,今天就继续开这个大坑的重点内容:train_net

现在就先看看train_net整个函数是怎么运行的

def train_net(args):torch.manual_seed(7)np.random.seed(7)checkpoint = args.checkpointstart_epoch = 0best_loss = float('inf')writer = SummaryWriter()epochs_since_improvement = 0decays_since_improvement = 0# Initialize / load checkpointif checkpoint is None:model = DIMModel(n_classes=1, in_channels=4, is_unpooling=True, pretrain=True)model = nn.DataParallel(model)if args.optimizer == 'sgd':optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom,weight_decay=args.weight_decay)else:optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)else:checkpoint = torch.load(checkpoint)start_epoch = checkpoint['epoch'] + 1epochs_since_improvement = checkpoint['epochs_since_improvement']model = checkpoint['model'].moduleoptimizer = checkpoint['optimizer']logger = get_logger()# Move to GPU, if availablemodel = model.to(device)# Custom dataloaderstrain_dataset = DIMDataset('train')train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)valid_dataset = DIMDataset('valid')valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)# Epochsfor epoch in range(start_epoch, args.end_epoch):if args.optimizer == 'sgd' and epochs_since_improvement == 10:breakif args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:checkpoint = 'BEST_checkpoint.tar'checkpoint = torch.load(checkpoint)model = checkpoint['model']optimizer = checkpoint['optimizer']decays_since_improvement += 1print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)# One epoch's trainingtrain_loss = train(train_loader=train_loader,model=model,optimizer=optimizer,epoch=epoch,logger=logger)effective_lr = get_learning_rate(optimizer)print('Current effective learning rate: {}\n'.format(effective_lr))writer.add_scalar('Train_Loss', train_loss, epoch)writer.add_scalar('Learning_Rate', effective_lr, epoch)# One epoch's validationvalid_loss = valid(valid_loader=valid_loader,model=model,logger=logger)writer.add_scalar('Valid_Loss', valid_loss, epoch)# Check if there was an improvementis_best = valid_loss < best_lossbest_loss = min(valid_loss, best_loss)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0decays_since_improvement = 0# Save checkpointsave_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

前面那些变量经过查找资料之后注释如下:

epochs_since_improvement和decays_since_improvement在后续的遍历有所体现,到后面再具体说是做什么的,end_epoch在上一讲中的argparse里面有所涉及,换句话说就类似于提前默认好了一个变量直接用,而且在参数修改的时候用命令行就能进行修改。

下一步会先去判断有没有检查点也就是事先练好的模型,如果没有的话就创建一个模型model,然后再判断优化器的类型来决定模型使用的优化器。由于篇幅关系DIM_MODEL这个今天就不做详解,整体的角度捋一遍整个train_net函数的结构。

这里面有一句dataparallel,这个是为了让模型能在多个gpu运行,因为用一个gpu跑dim的话显存有限而且时间太长,为了方便训练一般都使用多个显卡一起炼丹,笔者之前试过8块3090的效果,真的快如闪电。。。。

数据集dataloader的具体实现本期跳过,直接到后面的重点内容:训练过程

# Epochsfor epoch in range(start_epoch, args.end_epoch):if args.optimizer == 'sgd' and epochs_since_improvement == 10:breakif args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:checkpoint = 'BEST_checkpoint.tar'checkpoint = torch.load(checkpoint)model = checkpoint['model']optimizer = checkpoint['optimizer']decays_since_improvement += 1print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)# One epoch's trainingtrain_loss = train(train_loader=train_loader,model=model,optimizer=optimizer,epoch=epoch,logger=logger)effective_lr = get_learning_rate(optimizer)print('Current effective learning rate: {}\n'.format(effective_lr))writer.add_scalar('Train_Loss', train_loss, epoch)writer.add_scalar('Learning_Rate', effective_lr, epoch)# One epoch's validationvalid_loss = valid(valid_loader=valid_loader,model=model,logger=logger)writer.add_scalar('Valid_Loss', valid_loss, epoch)# Check if there was an improvementis_best = valid_loss < best_lossbest_loss = min(valid_loss, best_loss)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0decays_since_improvement = 0# Save checkpointsave_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

乍一看很简单,实际上就是各种套娃。。。初始的状态如下:

start_epoch=0

end_epoch = 100(可以命令行设定)

epochs_since_improvement = 0

decays_since_improvement = 0

从这里就必须要弄明白一个关键的事情:为什么要设置since_improvement这类的变量,上来判断说如果epochs_since_improvement==10的时候训练就停,这是必须要思考的问题。那就只盯着epochs_since_improvement和decays_since_improvement,直到整个代码的最后一块才找到问题的所在。

 # One epoch's validationvalid_loss = valid(valid_loader=valid_loader,model=model,logger=logger)writer.add_scalar('Valid_Loss', valid_loss, epoch)# Check if there was an improvementis_best = valid_loss < best_lossbest_loss = min(valid_loss, best_loss)if not is_best:epochs_since_improvement += 1print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))else:epochs_since_improvement = 0decays_since_improvement = 0# Save checkpointsave_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)

之前说过,best_loss初始的时候是正无穷为了方便后续的损失也就是valid_loss进行更新,那么如果valid_loss小于best_loss,那么此时此刻best_loss更新为更小的数值,然后is_best会变成1,此时此刻epochs_since_improvement和decays_since_improvement就会更新为0,反之如果best_loss更小,那么epochs_since_improvement就会加1

再回到开头,如果epochs_since_improvement==10的时候就终止循环,也就是说这里面有10次的损失值没法更新了,那么这个变量的作用就体现出来:避免过多的训练导致资源的浪费,既然有连续十次的损失函数没法更新,那就没必要接着玩。

那么decays_since_improvement这个东西又是咋回事?往前面看看。

if args.optimizer == 'sgd' and epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:checkpoint = 'BEST_checkpoint.tar'checkpoint = torch.load(checkpoint)model = checkpoint['model']optimizer = checkpoint['optimizer']decays_since_improvement += 1print("\nDecays since last improvement: %d\n" % (decays_since_improvement,))adjust_learning_rate(optimizer, 0.6 ** decays_since_improvement)

此时此刻在进入循环的时候如果之前出来的epochs_since_improvement大于0且能被2整除(正偶数),就直接在checkpoint里面进行运作。因为我从没用过他提供的checkpoint,而且正常运行的话损失函数在每一次的循环之后都会朝着更低的方向来跑,所以这里面我的猜测就是因为使用了已经训练好的模型的checkpoint,所以在训练的时候就会出现多次的最佳损失值无法更新,因此在运行的时候直接调用checkpiont里面的参数。后续再看很多代码的训练函数都有这么写的,因此这一功能就显得特别重要。

其实到了这块整个训练的代码结构的大概就已经展现在眼前了。在前期准备工作就绪之后,直接在每一个epoch里面进行模型循环然后再得到损失值进行更新,但是这里面的细节还得放到后面填坑

1.dim的模型结构

2.train和valid都是做什么的

3.writer到底有什么作用

这些放到后面填坑吧

手撕代码1:Deep image matting(3)相关推荐

  1. 和12岁小同志搞创客开发:手撕代码,做一款遥控灯

    机缘巧合在网上认识一位12岁小同志,从零开始系统辅导其创客开发思维和技巧. 项目专栏:https://blog.csdn.net/m0_38106923/category_11097422.html ...

  2. 和12岁小同志搞创客开发:手撕代码,做一款声控灯

    机缘巧合在网上认识一位12岁小同志,从零开始系统辅导其创客开发思维和技巧. 项目专栏:https://blog.csdn.net/m0_38106923/category_11097422.html ...

  3. 手撕代码之七大常用排序算法 | 附完整代码

    点击上方↑↑↑蓝字关注我们~ 「2019 Python开发者日」全日程揭晓,请扫码咨询 ↑↑↑ 0.导语 本节为手撕代码系列之第一弹,主要来手撕排序算法,主要包括以下几大排序算法: 直接插入排序 冒泡 ...

  4. Interview:算法岗位面试—11.06早上上海某智能驾驶科技公司(创业)笔试+面试之手撕代码、项目考察、比赛考察、图像算法的考察等

    Interview:算法岗位面试-11.06早上上海某智能驾驶科技公司(创业)笔试+面试之手撕代码.项目考察.比赛考察.图像算法的考察等 导读:该公司是在同济某次大型招聘会上投的,当时和HR聊了半个多 ...

  5. 蛇形打印数组(某宝典公司面试手撕代码题)

    背景杂谈 不知道为什么,可能脑袋一下放空了,一不小心就想到了大约2年前,在某个知名的宝典公司面试中,遇到了一道手撕代码题,和多年前的google的那道螺旋遍历数据有异曲同工之妙.现脑洞大开,想写下与大 ...

  6. 前端date format_前端面试-手撕代码篇

    前言 在前端面试有一个非常重要的环节,也是面试者最担心的一个环节.对"手撕代码"的考察需要面试者平时总结和积累(临时抱佛脚是不好使的),在这里笔者就自己如何攻破"手撕代码 ...

  7. 秋招总结:遇到的手撕代码题

    2020年秋招总结:遇到的手撕代码题 跟谁学 一面:求连续子数组的最大和(力扣 53) [思路:力扣系列略,题解区都比我讲得好] 二面:翻转字符串中的每个单词(简单题,比较常见,没去找对应的原题) [ ...

  8. 【数字IC手撕代码】Verilog奇数分频|题目|原理|设计|仿真(三分频,五分频,奇数分频及特殊占空比)

    芯片设计验证社区·芯片爱好者聚集地·硬件相关讨论社区·数字verifier星球 四社区联合力荐!近500篇数字IC精品文章收录! [数字IC精品文章收录]学习路线·基础知识·总线·脚本语言·芯片求职· ...

  9. 数字IC手撕代码-兆易创新笔试真题

    前言: 本专栏旨在记录高频笔面试手撕代码题,以备数字前端秋招,本专栏所有文章提供原理分析.代码及波形,所有代码均经过本人验证. 目录如下: 1.数字IC手撕代码-分频器(任意偶数分频) 2.数字IC手 ...

  10. 【2023校招刷题】笔试及面试中常考知识点、手撕代码总结

    文章目录 一.笔试/面试常考知识点 二.面试常考手撕代码 2.1.基本电路设计 2.2.复杂电路设计 2.3.跨时钟域设计 一.笔试/面试常考知识点 奇.偶.小数分频 [Verilog基础]分频器实现 ...

最新文章

  1. 【计算理论】计算复杂性 ( 多项式等价 | P 类 | 丘奇-图灵论题延伸 )
  2. Ulua_toLua_基本案例(八)_LuaAccessingArray
  3. MySQL5 基础语法与操作
  4. C 实现 删除字符串空白符的函数 strtrim
  5. java 项目启动初始化_Spring Boot解决项目启动时初始化资源的方法
  6. .net Compact Framework 程序设计起步(智能设备的程序设计)
  7. 木兰编程语言重现:引用本地木兰模块;模拟凑十法加法
  8. SSM-物流管理常见问题4 前端向后端传递数据
  9. android 文件加密源码
  10. 公募权益类基金投资者盈利洞察报告
  11. html中span隐藏属性,JS 如果改变span标签的是否隐藏属性
  12. 幅相曲线渐近线_对数幅频特性渐近线的绘制
  13. sundancest201驱动_驱动支持列表
  14. 一张图看懂零维到十维空间
  15. MFC编辑框数据读写
  16. 抖音短视频运营借势热点:有哪些热点渠道,有哪些热点改编的套路。
  17. excel表格如何打斜杠
  18. Powershell 过火绒免杀上线
  19. 麦克斯韦方程组,史上最牛逼公式之一
  20. 毕业设计 STM32空气质量检测仪 - 单片机 嵌入式

热门文章

  1. 专业的幼儿园设计公司是怎样的?
  2. 怎样将dwg转换成pdf格式?
  3. WPS表格在一列里批量导入图片
  4. 有哪些gif动画制作软件比较好用?
  5. 第18课:生活中的外观模式——学妹别慌,学长帮你
  6. 2.2 多项式乘法与加法运算(线性结构,C)
  7. 配置IIS服务器环境--win10
  8. MySQL数据库——MySQL下载安装
  9. 车载儿童滞留检测系统
  10. C++写个三维模型展UV