我为了看 每个epoch 的平均loss

新建一个 list []
for 每个 step将 该 step 的 loss 放在上边那个list中
打印那个list的 均值

很麻烦,每次都得重新,而且 不优雅 ,主要是 不优雅

和上一篇氵文一样,本文也是从 那个baseline 中扒下来的:

定义一个 Counter 类

用于记录他的均值、次数和总和
(当然你也可以加一个 last_num 用于记录上一个值)

class Counter:def __init__(self):self.count, self.sum, self.avg = 0, 0, 0returndef update(self, value, num_updata=1):self.count += num_updataself.sum += value * num_updataself.avg = self.sum / self.countreturndef clear(self):self.count, self.sum, self.avg = 0, 0, 0return

用法:

# 在大循环外部先初始化这个类
loss_recorder = Counter()for epoch in range(EPOCH):for batch in batch_loader:out = model(batch)loss = Loss(out, gt)# 这里更新那个对象loss_recorder.update(loss) # 当这里的loss你要转化成 float 的optimizer.zero_grad() # 优化器清空loss.backward() # 反向求导optimizer.step() # 更新优化器# 每个 epoch 打印一下这个lossprint(loss_recorder.avg)

(你大概会说,就这,就这?? 额,好吧,感觉可能确实,就这,捂脸笑哭)

每次跑完,loss 就没了,我想保存下来,那么咱们用这个 Loss_Saver

class Loss_Saver:def __init__(self, moving=False):self.loss_list, self.last_loss = [], 0.0self.moving = moving  # 是否进行滑动平均操作def updata(self, value):# 只有进行滑动平均时,才会用到 self.last_loss if not self.moving:self.loss_list += [value]elif not self.loss_list:self.loss_list += [value]self.last_loss = valueelse:update_val = self.last_loss * 0.9 + value * 0.1self.loss_list += [[update_val]]self.last_loss = update_valreturndef loss_drawing(self, root_file, encoding='gbk'):# 这个用于在指定位置保存 loss 指标loss_array = np.array(self.loss_list)colname = ['loss']listPF = pd.DataFrame(columns=colname, data=loss_array)listPF.to_csv(f'{root_file}loss.csv', encoding=encoding)

这个类,用于在指定位置保存结果,用的时候这样用:

# 在训练的for循环开始之前,先定义一个loss保存类
losssaver, max_acc = Loss_Saver(), 0.0

然后就是训练中途,保存一下loss

 for epoch in range(opt.max_epoch):model.train()epoch_loss = Counter()train_dataset.update_num_epoch(epoch)for batch_id, batch in tqdm(enumerate(dataloader):inputs, target = batchpred = model(inputs)loss = loss_fn(pred, target)epoch_loss.updata(float(loss.item())) # 这里是之前咱们用的那个loss(Counter类)optimizer.zero_grad()loss.backward()optimizer.step()losssaver.updata(epoch_loss.avg)  # <----------- 这里就是保存loss的部分

在训练完毕后,把刚刚的loss保存一下:

root_exp_file = "model"
losssaver.loss_drawing(f'{root_exp_file}/{name_exp}/')
logger.info('finish training!')   # 顺便打印一个 finish training hhh(上一篇博客的)

这里这个 name_exp 变量也值得说叨说叨,一般来说,我们每次运行这个train文件,loss每次都会被覆盖,在懒得想名字的情况下,我们直接给他赋值一个随机的名字,时间一长(一般我一小时就忘了)诶,哪个是哪个的loss来着hhh

于是可以借鉴这个baseline中起名字的方法,直接给文件夹的名字赋值为 当前的日期和时间

from datetime import datetimetime = datetime.now()
name_exp = f'{str(time.month).zfill(2)}{str(time.day).zfill(2)}_{str(time.hour).zfill(2)}' \f'{str(time.minute).zfill(2)}'
>>> name_exp
0424_1636

--------------- 完 ---------------

[每日一氵] Python 训练过程中,如何优雅的保存loss相关推荐

  1. (转)YOLO-V3可视化训练过程中的参数,绘制loss、IOU、avg Recall等的曲线图

    https://blog.csdn.net/qq_34806812/article/details/81459982 看了好几个博客,发现了些问题,有些博客是有bug的,此博客亲测无误. 查看全文 h ...

  2. 【python处理过程中的数据另保存为CSV文件】

    ##保存csv文件1 dataframe = pd.DataFrame({'ID':test_index,'PRICE': y_pred}) # dataframe = pd.DataFrame({' ...

  3. DeepLearning tutorial(2)机器学习算法在训练过程中保存参数

    FROM: http://blog.csdn.net/u012162613/article/details/43169019 DeepLearning tutorial(2)机器学习算法在训练过程中保 ...

  4. Python爬虫过程中验证码识别的三种解决方案

    在Python爬虫过程中,有些网站需要验证码通过后方可进入网页,目的很简单,就是区分是人阅读访问还是机器爬虫.验证码问题看似简单,想做到准确率很高,也是一件不容易的事情.为了更好学习爬虫,后续推文中将 ...

  5. 学习python/pytorch过程中遇到的知识点

    Pytorch torch.backends.cudnn.deterministic 和 torch.backends.cudnn.benchmark 这两个参数,用于固定算法,使每次运行结果都一样. ...

  6. 在Caffe的训练过程中打印验证集的预测结果

    起因:Caffe里的GoogLeNet Inception V1只能输出对应于三个loss的accuracy,我想计算precision,recall和F1-measure.但是调用caffe的Pyt ...

  7. 神经网络训练过程中出现loss为nan,神经元坏死

    最近在手撸Tensorflow2版本的Faster RCNN模型,稍后会进行整理.但在准备好了模型和训练数据之后的训练环节中出现了大岔子,即训练过程中loss变为nan.nan表示not a numb ...

  8. Pytorch在训练过程中常见的问题

    1 Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same 仔细看错误信息,CUDA和CPU, ...

  9. 理解YOLOv2训练过程中输出参数含义

    转载自https://blog.csdn.net/dcrmg/article/details/78565440 原英文地址: https://timebutt.github.io/static/und ...

  10. dqn在训练过程中loss越来越大_用DQN算法玩FlappyBird

    DQN算法可以用于解决离散的动作问题,而FlappyBird的操作正好是离散的. FlappyBird的游戏状态一般可以通过图像加卷积神经网络(CNN)来进行强化学习.但是通过图像分析会比较麻烦,因为 ...

最新文章

  1. Flask-请求上下文
  2. C++ 读入优化与输出优化 模板
  3. 快速发包突破ARP防火墙思路
  4. springcloud微服务实战--笔记--1、基础知识
  5. 转载 从SRAM中读写一个数据问题——Verilog
  6. 如何在excel不同的工作表之间使用数据有效性?
  7. 数据结构(十五)dijkstra单源最短路径
  8. ubuntu阿里云快速下载
  9. Deformation Transfer for Triangle Meshes
  10. json,pickle,shelve序列化和反序列化
  11. 程序员最想得到的十大证件,你最想得到哪个?
  12. 【Vue】Nodejs下载与安装
  13. 根据消费定额生成菜单的算法(原创)
  14. 无法修改计算机时间权限,无法修改系统时间怎么办
  15. C语言实现客房管理系统
  16. python之迭代器和生成器全解--包含实现原理及应用场景
  17. 【BFS】lydsy3161 孤舟蓑笠翁
  18. 初识-Python-day03
  19. 日常计算机网络基础练习题(每天进步一点点系列)
  20. 小说文学行业之“盛大文学”

热门文章

  1. Educational Codeforces Round 101 (Rated for Div. 2)
  2. 外贸常用术语_外贸常用术语大全
  3. base_url 是什么
  4. 链接mysql 504_phpMyAdmin错误代码:504 MySQL查询
  5. mq使用replyto队列进行消息回复
  6. android手机扩容软件,Android手机 6.0 + TF卡 扩容新选择
  7. html里 alt属性什么意思,HTML
  8. [转载] 柴静《看见》新书发布会
  9. 利率掉期(利率互换)的解释
  10. 对Orders订单表中的常见统计查询