Python 绘制 loss 曲线、准确率曲线

使用 python 绘制网络训练过程中的的 loss 曲线以及准确率变化曲线,这里的主要思想就时先把想要的损失值以及准确率值保存下来,保存到 .txt 文件中,待网络训练结束,我们再拿这存储的数据绘制各种曲线。

其大致步骤为:数据读取与存储 - > loss曲线绘制 - > 准确率曲线绘制

一、数据读取与存储部分

我们首先要得到训练时的数据,以损失值为例,网络每迭代一次都会产生相应的 loss,那么我们就把每一次的损失值都存储下来,存储到列表,保存到 .txt 文件中。保存的文件如下图所示:

[1.3817585706710815, 1.8422836065292358, 1.1619832515716553, 0.5217241644859314, 0.5221078991889954, 1.3544578552246094, 1.3334463834762573, 1.3866571187973022, 0.7603049278259277]

上图为部分损失值,根据迭代次数而异,要是迭代了1万次,这里就会有1万个损失值。
而准确率值是每一个 epoch 产生一个值,要是训练100个epoch,就有100个准确率值。

(那么问题来了,这里的损失值是怎么保存到文件中的呢? 很少有人讲这个,也有一些小伙伴们来咨询,这里就统一记录一下,包括损失值和准确率值。)

首先,找到网络训练代码,就是项目中的 main.py,或者 train.py ,在文件里先找到训练部分,里面经常会有这样一行代码:

for epoch in range(resume_epoch, num_epochs):   # 就是这一行####...loss = criterion(outputs, labels.long())              # 损失样例...epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例...###

从这一行开始就是训练部分了,往下会找到类似的这两句代码,就是损失值和准确率值了。

这时候将以下代码加入源代码就可以了:

train_loss = []
train_acc = []
for epoch in range(resume_epoch, num_epochs):          # 就是这一行###...loss = criterion(outputs, labels.long())           # 损失样例train_loss.append(loss.item())                     # 损失加入到列表中...epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例train_acc.append(epoch_acc.item())                 # 准确率加入到列表中...
with open("./train_loss.txt", 'w') as train_los:train_los.write(str(train_loss))with open("./train_acc.txt", 'w') as train_ac:train_ac.write(str(train_acc))

这样就算完成了损失值和准确率值的数据存储了!

二、绘制 loss 曲线

主要需要 numpy 库和 matplotlib 库,如果不会安装可以自行百度,很简单。

首先,将 .txt 文件中的存储的数据读取进来,以下是读取函数:

import numpy as np# 读取存储为txt文件的数据
def data_read(dir_path):with open(dir_path, "r") as f:raw_data = f.read()data = raw_data[1:-1].split(", ")   # [-1:1]是为了去除文件中的前后中括号"[]"return np.asfarray(data, float)

然后,就是绘制 loss 曲线部分:

if __name__ == "__main__":train_loss_path = r"E:\relate_code\Gaitpart-master\train_loss.txt"   # 存储文件路径y_train_loss = data_read(train_loss_path)        # loss值,即y轴x_train_loss = range(len(y_train_loss))            # loss的数量,即x轴plt.figure()# 去除顶部和右边框框ax = plt.axes()ax.spines['top'].set_visible(False)ax.spines['right'].set_visible(False)plt.xlabel('iters')    # x轴标签plt.ylabel('loss')     # y轴标签# 以x_train_loss为横坐标,y_train_loss为纵坐标,曲线宽度为1,实线,增加标签,训练损失,# 默认颜色,如果想更改颜色,可以增加参数color='red',这是红色。plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")plt.legend()plt.title('Loss curve')plt.show()

这样就算把损失图像画出来了!如下:

三、绘制准确率曲线

有了上面的基础,这就简单很多了。
只是有一点要记住,上面的x轴是迭代次数,这里的是训练轮次 epoch。

if __name__ == "__main__":train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt"   # 存储文件路径y_train_acc = data_read(train_acc_path)       # 训练准确率值,即y轴x_train_acc = range(len(y_train_acc))          # 训练阶段准确率的数量,即x轴plt.figure()# 去除顶部和右边框框ax = plt.axes()ax.spines['top'].set_visible(False)ax.spines['right'].set_visible(False)plt.xlabel('epochs')    # x轴标签plt.ylabel('accuracy')     # y轴标签# 以x_train_acc为横坐标,y_train_acc为纵坐标,曲线宽度为1,实线,增加标签,训练损失,# 增加参数color='red',这是红色。plt.plot(x_train_acc, y_train_acc, color='red',linewidth=1, linestyle="solid", label="train acc")plt.legend()plt.title('Accuracy curve')plt.show()

这样就把准确率变化曲线画出来了!如下:

以下是完整代码,以绘制准确率曲线为例,并且将x轴换成了iters,和损失曲线保持一致,供参考:

import numpy as np
import matplotlib.pyplot as plt# 读取存储为txt文件的数据
def data_read(dir_path):with open(dir_path, "r") as f:raw_data = f.read()data = raw_data[1:-1].split(", ")return np.asfarray(data, float)# 不同长度数据,统一为一个标准,倍乘x轴
def multiple_equal(x, y):x_len = len(x)y_len = len(y)times = x_len/y_leny_times = [i * times for i in y]return y_timesif __name__ == "__main__":train_loss_path = r"E:\relate_code\Gaitpart-master\file_txt\train_loss.txt"train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt"y_train_loss = data_read(train_loss_path)y_train_acc = data_read(train_acc_path)x_train_loss = range(len(y_train_loss))x_train_acc = multiple_equal(x_train_loss, range(len(y_train_acc)))plt.figure()# 去除顶部和右边框框ax = plt.axes()ax.spines['top'].set_visible(False)ax.spines['right'].set_visible(False)plt.xlabel('iters')plt.ylabel('accuracy')# plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")plt.plot(x_train_acc, y_train_acc,  color='red', linestyle="solid", label="train accuracy")plt.legend()plt.title('Accuracy curve')plt.show()

日常学习记录,一起交流讨论吧!侵权联系~

Python绘制loss曲线、准确率曲线相关推荐

  1. Caffe学习系列(19): 绘制loss和accuracy曲线

    转载自: Caffe学习系列(19): 绘制loss和accuracy曲线 - denny402 - 博客园 http://www.cnblogs.com/denny402/p/5110204.htm ...

  2. Caffe---Pycaffe 绘制loss和accuracy曲线

    Caffe---Pycaffe 绘制loss和accuracy曲线 <Caffe自带工具包---绘制loss和accuracy曲线>:可以看出使用caffe自带的工具包绘制loss曲线和a ...

  3. python画直方图成绩分析-使用Python绘制直方图和正态分布曲线

    本文主要介绍两个内容: 如何使用记事本生成包含某一数据集的CSV文件: 如何使用Python绘制给定数据集的直方图和正态分布曲线. 1. 使用记事本创建CSV文件 ① 新建一个文本文件,打开后输入数据 ...

  4. python导入数据画直方图加正态曲线_使用Python绘制直方图和正态分布曲线

    原博文 2020-03-20 22:01 − 本文主要介绍两个内容: 如何使用记事本生成包含某一数据集的CSV文件: 如何使用Python绘制给定数据集的直方图和正态分布曲线. 1. 使用记事本创建C ...

  5. Python用log文件绘制损失、准确率曲线

    一.导入包 from matplotlib import rcParams import matplotlib.pyplot as plt import re 二.读取文件 ##显示中文 rcPara ...

  6. Python绘制三次贝塞尔曲线

    对于贝塞尔曲线而言,其特点在于第一个控制点恰好是曲线的起点,最后一个控制点是曲线的终点,其他控制点并不在曲线上,而是起到控制曲线形状的作用.另外,曲线的起点处与前两个控制点构成的线段相切,而曲线的终点 ...

  7. python生成loss/acc训练曲线

    根据训练模型的工作日志生成loss(acc)曲线 准备数据 利用excel将工作日志中的iter.loss.acc分别提出来单独放置文本文件中,获得iter.txt.acc.txt和loss.txt ...

  8. python绘制曲线视频_使用Python绘制各种优美简单曲线

    matplotlib是著名的Python绘图库,它提供了一整套绘图API,十分适合交互式绘图.,解决数据分析和可视化问题,其实也是Python的拿手好戏.另外,在数据处理过程中会用到numpy. 例如 ...

  9. 神经网络训练时如何绘制loss的动态曲线

    在神经网络训练中,可以利用tensorboard进行查看loss曲线及graph图,但是比较麻烦,本人想在训练代码中加入一段代码,实现train_loss及val_loss的实时动态变化,方便观察损失 ...

最新文章

  1. cython安装、使用
  2. java b2b2c开源商城系统源码
  3. [剑指offer] 跳台阶
  4. AJAX应注意IIS有没有.ashx扩展
  5. mac地址 linux c api,如何使用C程序获取linux中接口的mac地址?
  6. Java监控工具VisualVM
  7. 《In Search of an Understandable Consensus Algorithm》翻译
  8. 接口 vs 抽象类 的区别
  9. mysql 事物隔离级别解读
  10. sqlalchemy mysql教程_SQLAlchemy 教程 —— 基础入门篇
  11. 三维观察---三维裁剪算法
  12. 电子邮件如何运行(MTA,MDA,MUA)
  13. 数学建模方法 — 【01】模糊数学
  14. 频谱泄露和吉布斯现象
  15. 如何管理一个外包团队小总结
  16. layui自定义新增tab页方法
  17. “二码合一”健康码和行程码一次出示即可
  18. 敲笨钟 (20 分)
  19. win7的终端服务器,win7系统远程提示终端服务器超出了最大允许连接的解决方法...
  20. Msfvenom编码免杀技术实现免杀实战

热门文章

  1. python支持哪些数据类型_Python支持的数据类型
  2. java计算机毕业设计健康食谱系统服务器端源码+mysql数据库+系统+lw文档+部署
  3. 介入治疗为肿瘤治疗带来生机
  4. Ubuntu上配置nginx及相关命令
  5. CIFilter 滤镜 ,分别有什么作用
  6. PHP游戏服务器的设计思路
  7. @service解决 error creating bean with name(XXX)的问题
  8. Go 语言GC原理概述
  9. EOJ 2849 成绩排序 C++
  10. 【总结】- 从 0 到 1 上手 Web Components 业务组件库开发