参考文档:https://mp.weixin.qq.com/s/UYnBRU2b0InzM9H1xl4b4g

在之前的第二篇笔记中,我们实现了一个 CNN 网络,在 mnist 上通过两个卷积层完成分类识别。但是在我们调试代码的过程中,其实往往会想要知道我们的网络训练过程中的效果变化,比如 loss 和 accuracy 的变化曲线。

当然,我们可以像前面的文章一样,将训练过程中的数据数据打印出来,但是一个是不够直观,另外一个是没有图形的表现力强。所以本篇笔记介绍了 tensorboard 来完成可视化的操作。

1、TensorBoard介绍

tensorboard 一开始是在 TensorFlow 中的可视化工具,它可以用来展示网络图、数据的处理流程、执行过程中的指标变化。特别是在训练网络的时候,网络参数的不同设置(比如:权重、偏置、卷积层数、全连接层数等)。

通俗点就是网络训练过程中的各种参数和指标的变化都可以展示成图表的形式,除此之外,还可以展示网络模型的结构,训练数据的照片等。

在 TensorFlow 中的优秀表现,使得 pytorch 从 1.2.0 版本开始,正式自带了 tensorboard。也就可以很方便的在 pytorch 训练中进行可视化了。

2、先run一个示例吧

按照我们的习惯,先找一个简单的例子 run 起来,再一步步去学习其中的原理吧。首先我们看一个官方文档的示例:

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
x = range(100)
for i in x:writer.add_scalar('y=2x', i * 2, i)
writer.close()

执行这一段代码后,tensorboard 会在 ‘./runs’ 文件夹下保存训练的日志。接下来我们仔细的分析一下这个例子,来帮助理解这里的具体执行过程:

  • 首先导入 SummaryWriter,这是一个类,需要实例化出一个对象,这里我们命名为 writer。初始化的参数我们最关心的一个就是 log 文件的地址。如果没有输入就默认为当前文件中的 runs 文件夹,如果没有 runs 文件夹则创建一个。
  • 接下来,我们看到在 for 循环中调用了 writer 的一个方法:add_scalar(),这个函数的参数定义如下:

def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):

  • 我们介绍一下这个方法比较重要的几个参数。这个方法的作用是将我们想要的数据存入 log 文件中,那么这段数据的标签(也就是我们画图时的 title)就是参数 tag,而 scalar_value 就是当前存入的数值(也就是画图时的 y 轴),而 global_step 就是我们在哪一步存入了一次数据(也就是画图时的 x 轴)。
  • 那么在这个循环中,我们可以看到存入的数据标签是 ‘y=2x’,而在第 i 步时存入的数值是 2*i,通过循环相当于将每一步时刻的 y 值都存了进去,也可以理解为将一幅图的每个点坐标存入了 log 文件中。
  • 最后使用 write.close() 终止掉这个对象。

然后我们在 terminal 调用下列命令:

这一步的思路相当于,使用 tensorboard 来调用目标的 log 文件中的数据进行可视化,后面也可以继续加参数 —port 来指定端口。我们这里没有指定,就会选择默认端口 6006.

然后我们将显示的这个地址:http://localhost:6006/ 复制到浏览器中,就可以看到如下结果:

这里我们可以看到,做出了一个 step 从 0 到 99 的图像,而纵轴则是在对应 step 的 value,最终的图像刚好是一个 y=2x 的斜线。

通过这个例子,我们知道了如何使用 tensorboard 完成一个完整的可视化流程。知道了关键的几个步骤的含义,那么接下来,我们就修改一下之前第二篇笔记的代码,来完成对一个 CNN 训练过程中数据的可视化。

3、可视化CNN训练数据

这里使用第二篇笔记中的代码进行修改,当时我们对代码训练过程中的 loss 进行了打印,每过 50 个 step 则计算一次在测试集上的 accuracy,然后打印训练集上的 loss 和测试集上的 accuracy。

前面的网络定义我们就不重复了,大家直接看 第二篇文章 里面的介绍。我们这里主要看一下训练过程中的代码:

writer = SummaryWriter('tb_mnist')
for epoch in range(EPOCH):for step, (b_x, b_y) in enumerate(train_loader):# print(b_x.shape); breakif cuda_gpu:b_x = b_x.cuda()b_y = b_y.cuda()output = cnn(b_x)loss = loss_func(output, b_y)optimizer.zero_grad()loss.backward()optimizer.step()if step % 50 == 0:test_output = cnn(test_x)pred_y = torch.max(test_output, 1)[1].dataif cuda_gpu:pred_y = pred_y.cpu().numpy()else:pred_y = pred_y.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data, '| test accuracy: %.2f' % accuracy)writer.add_scalar("Train/Accuracy", accuracy, step)writer.add_scalar("Train/Loss", loss.item(), step)

可以看到第一行同样是实例化了一个 writer,其余的都一样,直到下面最后两行代码做了修改。第一个是在每次打印的时候,将对应 step 的 accuracy 添加进了 Train/Accuracy 中;第二个是在这个打印的 for 循环的外面,也就是对每一个 step 都保存了当前的 loss 信息。

和上面一样,我们在命令行中执行 tensorboard 命令,指定好 log 文件的地址,在目标端口中就可以看到如下信息:

可以看到这样的两幅图,左边的 accuracy,右边的是 loss。而且可以看到左边的 accuracy 因为我们是每过 50 个 step 保存一次,所以数据明显有折痕。而 loss 是每个 step 都在打印,所以可以看到非常详细的描述了训练过程中 loss 的变化曲线。

在我们的模型训练过程中,就可以通过 add_scalar() 来描绘一些目标参数的变化过程。

4、图片和模型的可视化

介绍完了图表的绘制,我们再展示一个官方文档中的例子,来为大家进一步学习 tensorboard 的广泛用途。这里的例子分别是利用 tensorboard 展示图片和网络结构。先看一下官方的代码吧:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms# Writer will output to ./runs/ directory by default
writer = SummaryWriter()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(False)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()

有了前面的例子,我们就可以比较轻松的理解这段代码的意思了。首先数据选择了我们前面用的数据集 mnist,网络结构直接调用了 pytorch 的自带模型 resnet,通过这条命令调取:torchvision.models.resnet50()。

重点是最后一段的代码,grid 是用 torchvision 中集成的工具,将数据集中的一批照片读取进来,按网格状排列。照片的数量我们可以看到,前面 trainloader 设置了 batch_size = 64,也就是说 64 张照片。

然后开始用 writer 去调用对应的方法,图片用 add_image,模型图用 add_graph。接下来让我们去看一下效果:


这两张图分别展示了 image 和 graph 的效果,一个是 64 张训练时的照片进行展示,拼接成 8*8 的网格状;一个是 resnet50 的网络结构,这个网络结构还支持继续展开,只需要鼠标在上面双击即可,图中的情况就是我进行了适当的展开后的样子。

5、总结

今天的文章我们介绍了 tensorboard 的用途,然后通过几个例子,逐步深入了 tensorboard 的使用,当然它的功能远不止我们文中介绍的内容,但是作为入门已经够用了。如果大家还想再继续深入了解 tensorboard,可以访问官方文档:https://pytorch.org/docs/stable/tensorboard.html

官方文档中还有很多其它 tensorboard 的使用方法介绍,接口清晰而且都有示例。

希望今天的文章可以让大家了解到 tensorboard 的入门使用,在网络训练过程中,有效的可视化可以帮助我们快速定位 bug,更加深入的理解网络的训练效果。动手尝试一下吧~

pytorch学习笔记(4):tensorboard可视化相关推荐

  1. PyTorch学习笔记(七):PyTorch可视化

    PyTorch可视化 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一) ...

  2. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  3. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  4. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  5. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  6. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

  7. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  8. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  9. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  10. PyTorch学习笔记(二):PyTorch简介与基础知识

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

最新文章

  1. 无异常日志,就不能排查问题了???
  2. 独家 | 将人们困于贫穷之中的隐藏算法战争即将到来
  3. Appium 常见API 四(三种等待方式)
  4. 【工具】更新最新esp8266库离线安装包3.0.1、ESP32库离线安装包1.0.6
  5. 一文带你认识:Liunx的历史
  6. FFmpeg Filter基本使用
  7. PHP获取文件夹内所有文件包括子目录文件的名称或路径
  8. 五分钟带你了解什么是PID模糊算法
  9. libxml2 not found
  10. 跟随企业数字化转型,FIT2CLOUD推演全栈云管平台
  11. Unity3D中使用Projector生成阴影
  12. python查看list的shape_列表list、数组np.array等的len,size,shape操作
  13. WebGL编程指南三:varying变量的使用和理解光栅化过程。
  14. torch 显存管理
  15. 从键盘输入一个不多于3位的正整数,要求:求出它是几位数;分别打印出每一位数字;按逆序打印出各位数字
  16. 哈里森,史上最具空间价值的钟表匠
  17. 计算机中关于数字的进制转换
  18. HTML5文本元素解析
  19. 蓝牙协议HFP(Hands-Free Profile)电话免提协议 Connection management 连接管理HFP SLC 的建立跟释放
  20. Scrapy爬取顶点小说网

热门文章

  1. linux 无线命令
  2. HashMap中hash(Object key)原理(hashcode >>> 16)
  3. Dubbo(八)使用配置类方式实现服务提供者消费者dubbo配置
  4. Netty 的核心组件
  5. rpc调试工具grpcui的安装使用
  6. Go语言的context包从放弃到入门
  7. 使用Laravel Eloquent ORM 时如何查询表中指定的字段 1
  8. sublime text 3 插件推荐?
  9. RocketMQ消费失败如何处理?如何保证消费消息的幂等性?
  10. 那些用Go实现的分布式事务框架