本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善。

一、前言

训练深度学习模型,就像“炼丹”,模型可能需要训练很多天。

我们不可能像「太上老君」那样,拿着浮尘,24 小时全天守在「八卦炉」前,更何况人家还有炼丹童、天兵天将,轮流值守。

人手不够,“法宝”来凑。

本文就盘点一下,我们可以使用的「炼丹法宝」。

PS:文中出现的所有代码,均可在我的 Github 上下载:点击查看

二、初级“法宝”,sys.stdout

训练模型,最常看的指标就是 Loss。我们可以根据 Loss 的收敛情况,初步判断模型训练的好坏。

如果,Loss 值突然上升了,那说明训练有问题,需要检查数据和代码。

如果,Loss 值趋于稳定,那说明训练完毕了。

观察 Loss 情况,最直观的方法,就是绘制 Loss 曲线图。

通过绘图,我们可以很清晰的看到,左图还有收敛空间,而右图已经完全收敛。

通过 Loss 曲线,我们可以分析模型训练的好坏,模型是否训练完成,起到一个很好的“监控”作用。

绘制 Loss 曲线图,第一步就是需要保存训练过程中的 Loss 值。

一个最简单的方法是使用,sys.stdout 标准输出重定向,简单好用,实乃“炼丹”必备“良宝”。

import os
import sys
class Logger():def __init__(self, filename="log.txt"):self.terminal = sys.stdoutself.log = open(filename, "w")def write(self, message):self.terminal.write(message)self.log.write(message)def flush(self):passsys.stdout = Logger()print("Jack Cui")
print("https://cuijiahua.com")
print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代码很简单,创建一个 log.py 文件,自己写一个 Logger 类,并采用 sys.stdout 重定向输出。

在 Terminal 中,不仅可以使用 print 打印结果,同时也会将结果保存到 log.txt 文件中。

运行 log.py,打印 print 内容的同时,也将内容写入了 log.txt 文件中。

使用这个代码,就可以在打印 Loss 的同时,将结果保存到指定的 txt 中,比如保存上篇文章训练 UNet 的 Loss。

三、中级“法宝”,matplotlib

Matplotlib 是一个 Python 的绘图库,简单好用。

简单几行命令,就可以绘制曲线图、散点图、条形图、直方图、饼图等等。

在深度学习中,一般就是绘制曲线图,比如 Loss 曲线、Acc 曲线。

举一个,简单的例子。

使用 sys.stdout 保存的 train_loss.txt,绘制 Loss 曲线。

train_loss.txt 下载地址:点击查看

思路非常简单,读取 txt 内容,解析 txt 内容,使用 Matplotlib 绘制曲线。

import matplotlib.pyplot as plt
# Jupyter notebook 中开启
# %matplotlib inline
with open('train_loss.txt', 'r') as f:train_loss = f.readlines()train_loss = list(map(lambda x:float(x.strip()), train_loss))
x = range(len(train_loss))
y = train_loss
plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

指定 x 和 y 对应的值,就可以绘制。

是不是很简单?

关于 Matplotlib 更多的详细教程,可以查看官方手册:点击查看

四、中级“法宝”,Logging

说到保存日志,那不得不提 Python 的内置标准模块 Logging,它主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等,同时,我们也可以设置日志的输出格式。

import loggingdef get_logger(LEVEL, log_file = None):head = '[%(asctime)-15s] [%(levelname)s] %(message)s'if LEVEL == 'info':logging.basicConfig(level=logging.INFO, format=head)elif LEVEL == 'debug':logging.basicConfig(level=logging.DEBUG, format=head)logger = logging.getLogger()if log_file != None:fh = logging.FileHandler(log_file)logger.addHandler(fh)return loggerlogger = get_logger('info')logger.info('Jack Cui')
logger.info('https://cuijiahua.com')
logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要几行代码,进行一个简单的封装使用。使用函数 get_logger 创建一个级别为 info 的 logger,如果指定 log_file,则会对日志进行保存。

logging 默认支持的日志一共有 5 个等级:

日志级别等级 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默认的日志级别设置为 WARNING,也就是说如果不指定日志级别,只会显示大于等于 WARNING 级别的日志。

例如:

import logging
logging.debug("debug_msg")
logging.info("info_msg")
logging.warning("warning_msg")
logging.error("error_msg")
logging.critical("critical_msg")

运行结果:

WARNING:root:warning_msg
ERROR:root:error_msg
CRITICAL:root:critical_msg

可以看到 info 和 debug 级别的日志不会输出,默认的日志格式也比较简单。

默认的日志格式为日志级别:Logger名称:用户输出消息

当然,我们可以通过,logging.basicConfig 的 format 参数,设置日志格式。

字段有很多,可谓应有尽有,足以满足我们定制化的需求。

五、高级“法宝”,TensorboardX

上文介绍的“法宝”,并非针对深度学习“炼丹”使用的工具。

而 TensorboardX 则不同,它是专门用于深度学习“炼丹”的高级“法宝”。

早些时候,很多人更喜欢用 Tensorflow 的原因之一,就是 Tensorflow 框架有个一个很好的可视化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起来费劲儿不说,还有很多 Bug。

Pytorch 1.1.0 版本发布后,打破了这个局面,TensorBoard 成为了 Pytorch 的正式可用组件。

在 Pytorch 中,这个可视化工具叫做 TensorBoardX,其实就是针对 Tensorboard 的一个封装,使得 PyTorch 用户也能够调用 Tensorboard。

TensorboardX 安装也非常简单,使用 pip 即可安装。

pip install tensorboardX

tensorboardX 使用也很简单,编写如下代码。

from tensorboardX import SummaryWriter# 创建 writer1 对象
# log 会保存到 runs/exp 文件夹中
writer1 = SummaryWriter('runs/exp')# 使用默认参数创建 writer2 对象
# log 会保存到 runs/日期_用户名 格式的文件夹中
writer2 = SummaryWriter()# 使用 commet 参数,创建 writer3 对象
# log 会保存到 runs/日期_用户名_resnet 格式的文件中
writer3 = SummaryWriter(comment='_resnet')

使用的时候,创建一个 SummaryWriter 对象即可,以上展示了三种初始化 SummaryWriter 的方法:

  • 提供一个路径,将使用该路径来保存日志
  • 无参数,默认将使用 runs/日期_用户名 路径来保存日志
  • 提供一个 comment 参数,将使用 runs/日期_用户名+comment 路径来保存日志

运行结果:

有了 writer 我们就可以往日志里写入数字、图片、甚至声音等数据。

数字 (scalar)

这个是最简单的,使用 add_scalar 方法来记录数字常量。

add_scalar(tag, scalar_value, global_step=None, walltime=None)

总共 4 个参数。

  • tag (string): 数据名称,不同名称的数据使用不同曲线展示
  • scalar_value (float): 数字常量值
  • global_step (int, optional): 训练的 step
  • walltime (float, optional): 记录发生的时间,默认为 time.time()

需要注意,这里的 scalar_value 一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item() 方法获取其数值。我们一般会使用 add_scalar 方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。

运行如下代码:

from tensorboardX import SummaryWriter
writer = SummaryWriter('runs/scalar_example')
for i in range(10):writer.add_scalar('quadratic', i**2, global_step=i)writer.add_scalar('exponential', 2**i, global_step=i)
writer.close()

通过 add_scalar 往日志里写入数字,日志保存到 runs/scalar_example中,writer 用完要记得 close,否则无法保存数据。

在 cmd 中使用如下命令:

tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口号,在浏览器中,就可以使用如下地址,打开 Tensorboad。

http://localhost:8088/

省去了我们自己写代码可视化的麻烦。

图片 (image)

使用 add_image 方法来记录单个图像数据。注意,该方法需要 pillow 库的支持

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

  • tag (string):数据名称
  • img_tensor (torch.Tensor / numpy.array):图像数据
  • global_step (int, optional):训练的 step
  • walltime (float, optional):记录发生的时间,默认为 time.time()
  • dataformats (string, optional):图像数据的格式,默认为 'CHW',即 Channel x Height x Width,还可以是 'CHW'、'HWC' 或 'HW' 等

我们一般会使用 add_image 来实时观察生成式模型的生成效果,或者可视化分割、目标检测的结果,帮助调试模型。

from tensorboardX import SummaryWriter
from urllib.request import urlretrieve
import cv2urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')writer = SummaryWriter('runs/image_example')
for i in range(1, 4):writer.add_image('UNet_Seg',cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),global_step=i,dataformats='HWC')
writer.close()

代码就是下载上篇文章数据集里的三张图片,然后使用 Tensorboard 可视化处理来,使用 8088 端口开打 Tensorboard:

tensorboard --logdir=runs/image_example --port=8088

运行结果:

试想一下,一边训练,一边输出图片结果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方图、运行图、嵌入向量等,可以查看官方手册进行学习,方法都是类似的,简单好用。

官方文档:点击查看

六、总结

工欲善其事,必先利其器。

本文讲解了深度学习中,常用的“炼丹法宝”的使用方法,sys.stdout、matplotlib、logging、tensorboardX 你更喜欢哪一款?

点赞再看,养成习惯,微信公众号搜索【JackCui-AI】关注一个在互联网摸爬滚打的潜行者

Pytorch深度学习实战教程(四):必知必会的炼丹法宝相关推荐

  1. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  2. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  3. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  4. Pytorch 深度学习实战教程(六):仝卓自爆,快本打码。

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  5. Pytorch深度学习实战教程(一):语义分割基础与环境搭建

    Pytorch的基本使用&&语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 语义分割是什么? 语义分割(semantic segmentation) ...

  6. Pytorch 深度学习实战教程:今天,你垃圾分类了吗?

    1 垃圾分类 还记得去年,上海如火如荼进行的垃圾分类政策吗? 2020年5月1日起,北京也开始实行「垃圾分类」了! 北京的垃圾分类标准与上海略有差别,垃圾分为厨余垃圾.可回收物.有害垃圾和其他垃圾四大 ...

  7. Pytorch深度学习实战教程:语义分割基础与环境搭建

    一.前言 许久没有更新技术博文了,给自己挖一个新坑:语义分割系列文章. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 先从最简单的语义分割基础与开发环境搭建开始讲解. 二.语义分割 ...

  8. 【Pytorch】Pytorch深度学习实战教程:超分辨率重建AI与环境搭建

    一.基础开发环境搭建 1)cuda安装 需要根据自己的显卡的型号选择支持的CUDA版本 显卡驱动查看: 鼠标右键 CUDA安装版本查看:https://docs.nvidia.com/cuda/cud ...

  9. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  10. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

最新文章

  1. oss客户端工具_阿里云服务器ECS上使用ossfs工具挂载阿里云OSS存储
  2. 第一次使用aspnet_compiler失败记录
  3. magento mysql4-install_Magento
  4. 【做题记录】Codeforces做题记录
  5. python画图标题为蓝色_python绘制语谱图怎么设置成黄蓝色
  6. 计算机网络标准体系,计算机网络标准体系结构实验报告.doc
  7. [转]JavaScript事件(Event)
  8. 【华为云技术分享】使用pdb调试python代码的方法
  9. linux下的工作目录切换实现
  10. 南宁公交有两个应用付费通道,互不通用
  11. VMware 虚拟机NAT模式下却没有网
  12. 字典生成工具 -- pydictor
  13. 微信24小时到账_最新微信转账延迟24小时到账骗局
  14. 联想笔记本fn键linux,Linux 系统下笔记本电脑的 Fn 键失效
  15. Unity获取摄像机的视口区域(透视相机模式)
  16. 英语面试技巧以及准备工作
  17. html把图片放到文章右边,怎么在文章中把图片放在文字的左边、右边、中 – 手机爱问...
  18. ubuntu双系统时间不一致现象
  19. 2022全新好玩的恶搞屁声音小程序源码+UI不错
  20. Spring Boot事务

热门文章

  1. mysql批量导出工具_sql数据库批量导出|
  2. Sonic一站式开源分布式集群云真机测试平台阶段性使用总结
  3. android有道翻译api,有道智云自然翻译服务API
  4. 使用HTML制作在线电子时钟,用HTML5制作数字时钟的教程
  5. 2.7 汽车之家口碑爬虫
  6. 优秀的程序员是如何利用工具来提升工作效率的?
  7. 多线程(Thread的类的运用-Runnable类的使用/多线程的注意点)
  8. 这些实用的WhatsApp工具,赶快用起来
  9. matlab 空间向量的夹角,空间两向量之间的旋转角如何求?角度范围在0-360°
  10. 根据旋转角计算欧拉角 (Computing Euler angles from a rotation matrix)