stable-baselines3学习之Tensorboard
stable-baselines3学习之Tensorboard系列
1.基本用法
要使用stable-baselines3的 Tensorboard,您只需将日志文件夹的位置传递给 RL 的agent:
from stable_baselines3 import A2Cmodel = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000)
您还可以在训练时定义自定义日志名称(默认为算法名称)
from stable_baselines3 import A2Cmodel = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
调用 learn 函数后,您可以使用以下 bash 命令在训练期间或之后监控 RL agent:
tensorboard --logdir ./a2c_cartpole_tensorboard/
注:要在该项目文件路径下运行这条命令
比如:
2.Logging More Values
使用callback可以容易的记录更多日志用Tensorboard,这里有一个简单的例子去记录额外的tensor和任意的scalar值:
import numpy as npfrom stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallbackmodel = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac0/", verbose=1)class TensorboardCallback(BaseCallback):"""Custom callback for plotting additional values in tensorboard."""def __init__(self, verbose=0):super(TensorboardCallback, self).__init__(verbose)def _on_step(self) -> bool:# Log scalar value (here a random variable)value = np.random.random()self.logger.record('random_value', value)return Truemodel.learn(50000, callback=TensorboardCallback())
tensorboard --logdir ./tmp/sac0/
3.Logging Images
TensorBoard 支持定期记录图像数据,这有助于在训练期间的各个阶段评估agent。
以下是如何定期将图像渲染到 TensorBoard 的示例:
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Imagemodel = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac1/", verbose=1)class ImageRecorderCallback(BaseCallback):def __init__(self, verbose=0):super(ImageRecorderCallback, self).__init__(verbose)def _on_step(self):image = self.training_env.render(mode="rgb_array")# "HWC" specify the dataformat of the image, here channel last# (H for height, W for width, C for channel)# See https://pytorch.org/docs/stable/tensorboard.html# for supported formatsself.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))return Truemodel.learn(50000, callback=ImageRecorderCallback())
tensorboard --logdir ./tmp/sac1/
4.Logging Figures/Plots
TensorBoard 支持定期记录使用 matplotlib 创建的图形/绘图,这有助于在训练期间评估各个阶段的agent。
以下是如何在 TensorBoard 中定期存储绘图的示例:
import numpy as np
import matplotlib.pyplot as pltfrom stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import Figuremodel = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="./tmp/sac2/", verbose=1)class FigureRecorderCallback(BaseCallback):def __init__(self, verbose=0):super(FigureRecorderCallback, self).__init__(verbose)def _on_step(self):# Plot values (here a random variable)figure = plt.figure()figure.add_subplot().plot(np.random.random(3))# Close the figure after logging itself.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))plt.close()return Truemodel.learn(50000, callback=FigureRecorderCallback())
tensorboard --logdir ./tmp/sac1/
5.Logging Videos
TensorBoard 支持定期记录视频数据,这有助于在训练期间评估各个阶段的agent。
以下是如何显示一个episode并将生成的视频定期记录到 TensorBoard 的示例:
注:需安装moviepy
包
from typing import Any, Dictimport gym
import torch as thfrom stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Videoclass VideoRecorderCallback(BaseCallback):def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):"""Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard:param eval_env: A gym environment from which the trajectory is recorded:param render_freq: Render the agent's trajectory every eval_freq call of the callback.:param n_eval_episodes: Number of episodes to render:param deterministic: Whether to use deterministic or stochastic policy"""super().__init__()self._eval_env = eval_envself._render_freq = render_freqself._n_eval_episodes = n_eval_episodesself._deterministic = deterministicdef _on_step(self) -> bool:if self.n_calls % self._render_freq == 0:screens = []def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:"""Renders the environment in its current state, recording the screen in the captured `screens` list:param _locals: A dictionary containing all local variables of the callback's scope:param _globals: A dictionary containing all global variables of the callback's scope"""screen = self._eval_env.render(mode="rgb_array")# PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image conventionscreens.append(screen.transpose(2, 0, 1))evaluate_policy(self.model,self._eval_env,callback=grab_screens,n_eval_episodes=self._n_eval_episodes,deterministic=self._deterministic,)self.logger.record("trajectory/video",Video(th.ByteTensor([screens]), fps=40),exclude=("stdout", "log", "json", "csv"),)return Truemodel = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="./tmp/runs/", verbose=1)
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
tensorboard --logdir ./tmp/runs/
stable-baselines3学习之Tensorboard相关推荐
- 【深度学习】Tensorboard可视化模型训练过程和Colab使用
[深度学习]Tensorboard可视化模型训练过程和Colab使用 文章目录 1 概述 2 手撸代码实现 3 Colab使用3.1 详细步骤3.2 Demo 4 总结 1 概述 在利用TensorF ...
- AIGC - Stable Diffusion 学习踩坑实录总结
学习路径 淘宝拼多多找教程就没必要了,我踩过坑,还跟店主纠缠过,付了钱,不过都退了,淘宝平台介入,啥都能解决,现在卖得都是搬运的 B 站里面的大佬视频,我目前正在不断关注 B 站大佬的各种课程,探索更 ...
- Tensorflow学习教程------tensorboard网络运行和可视化
tensorboard可以将训练过程中的一些参数可视化,比如我们最关注的loss值和accuracy值,简单来说就是把这些值的变化记录在日志里,然后将日志里的这些数据可视化. 首先运行训练代码 #co ...
- 【深度学习】tensorboard中的图片放到论文中
1. 利用tensorboard中的导出数据功能 选中左上角的标签,然后选择csv格式的数据下载即可. 左边的smoothing 数值不影响导出数据的大小. 2.tensorboard 图下标签介绍: ...
- 深度学习(33)随机梯度下降十一: TensorBoard可视化
深度学习(33)随机梯度下降十一: TensorBoard可视化 Step1. run listener Step2. build summary Step3.1 fed scalar(监听标量) S ...
- 极客学院 TensorBoard:可视化学习
TensorBoard:可视化学习 TensorBoard 涉及到的运算,通常是在训练庞大的深度神经网络中出现的复杂而又难以理解的运算. 为了更方便 TensorFlow 程序的理解.调试与优化,我们 ...
- pytorch深度学习实战一书,tensorboard可视化踩坑
书评&踩坑 @[TOC](书评&踩坑) `提示:纯个人观点,仅供参考` 前言 一.源码学习,又是版本问题(省略内心独白...) 二.步骤 1.安装tensorflow 2.思考,看代码 ...
- 最NB强化学习路线图
人工智能是21世纪最激动人心的技术之一.人工智能,就是像人一样的智能,而人的智能包括感知.决策和认知(从直觉到推理.规划.意识等).其中,感知解决what,深度学习已经超越人类水平:决策解决how,强 ...
- 天下苦深度强化学习久矣,这有一份训练与调参技巧手册
©作者 | 申岳 单位 | 北京邮电大学 研究方向 | 机器人学习 天下苦 RL 久矣,其中最苦的地方莫过于训练和调参了,人人欲"调"之而后快. 在此为 RL 社区贡献一点绵薄之力 ...
最新文章
- 食品行业特点及SAP解决方案探讨
- rsync 模块同步失败
- Javascript基本概念之数据类型
- 复旦 哈工大计算机学院,国内高校中哈工大和上交复旦在一个档次吗?从这些方面看你就知道...
- Struts2的CRUD
- C#实践设计模式原则SOLID
- 设计模式(一)Chain Of Responsibility责任链模式
- 【OJ】洛谷顺序结构题单题解锦集
- 机器学习(六)支持向量机svm初级篇
- windows做ntp server,linux做ntp client端的配置方法
- lc300.最长递增子序列
- 掌握好这几点方法学习Linux,一定比别人更快入门运维!
- Entity Framework Plus
- 应用层TCP三次握手及各种协议简介telnet【笔记】
- 极点五笔特殊符号输入方法
- 分数加减法—两个分数的加减法
- 阿里云服务器操作系统怎么选择?
- 2021年汽车修理工(中级)考试题库及汽车修理工(中级)实操考试视频
- 电脑如何打开EPUB文件
- 读书笔记 -- 算法入门
热门文章
- 硅谷最凶猛的云计算“独角兽”:Snowflake造富神话 能否在中国复制?| 硅谷速递...
- 视频号迎来重大更新,这些功能久等了
- win七系统如何卸载MySQL_Win7完全卸载sql2005和删除sql2005的方法
- 2022年下半年网络规划设计师考试下午真题
- STL源码剖析——空间配置器
- 面试官最常问的面试题及答案,每1题都很经典
- 基于深度学习的人脸性别识别系统(含UI界面,Python代码)
- 华为路由器console口加密 telnet远程登录 DHCP server在路由器中的两种写法
- Yii2本身自带实现用户注册、登录
- 如何修复“Windows/System32/Config/System中文件丢失或损坏”故障