tensorboard是让torch可以使用tensorflow的web可视化工具,之前叫tensorboardX。
至于其他的介绍以及为什么需要,可自行百度。

简单的完整代码

1234567891011121314151617181920
# -*- coding: utf8 -*-#

import math

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./run')

epochs = 100data_loader = range(1000)

total_loss = 0for epoch in range(epochs):for i, data in enumerate(data_loader):loss = math.sin(i)total_loss += loss# 打印每一个batch下的总losswriter.add_scalar('train_loss', total_loss, epoch * len(data_loader) + i)total_loss = 0

终端跑下面这条命令:

1
tensorboard --logdir=./runs

更完整的demo

其中数据集加载没有添加进来,引用代码可参考lstm_sent_polarity。

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
# Defined in Section 4.6.7

import torchfrom torch import nn, optimfrom torch.nn import functional as Ffrom torch.nn.utils.rnn import pad_sequence, pack_padded_sequencefrom torch.utils import tensorboardfrom torch.utils.data import Dataset, DataLoader# tqdm是一个Python模块,能以进度条的方式显式迭代的进度from tqdm.auto import tqdm

from utils import load_sentence_polarity

summary = tensorboard.SummaryWriter('./runs')

class LstmDataset(Dataset):def __init__(self, data):self.data = data

def __len__(self):return len(self.data)

def __getitem__(self, i):return self.data[i]

def collate_fn(examples):lengths = torch.tensor([len(ex[0]) for ex in examples])inputs = [torch.tensor(ex[0]) for ex in examples]targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)# 对batch内的样本进行padding,使其具有相同长度inputs = pad_sequence(inputs, batch_first=True)return inputs, lengths, targets

class LSTM(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class):super(LSTM, self).__init__()self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True, num_layers=3)self.output = nn.Linear(hidden_dim * 6, num_class)

def forward(self, inputs, lengths):embeddings = self.embeddings(inputs)x_pack = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)hidden, (hn, cn) = self.lstm(x_pack)outputs = self.output(hn.permute(1, 0, 2).reshape(-1, 6 * 256))log_probs = F.log_softmax(outputs, dim=-1)return log_probs

embedding_dim = 128hidden_dim = 256num_class = 2batch_size = 32num_epoch = 20

# 加载数据train_data, test_data, vocab = load_sentence_polarity()train_dataset = LstmDataset(train_data)test_dataset = LstmDataset(test_data)train_data_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)test_data_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

# 加载模型device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)model.to(device)  # 将模型加载到GPU中(如果已经正确安装)

# 训练过程nll_loss = nn.NLLLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器scheduler = optim.lr_scheduler.StepLR(optimizer, 15, gamma=0.99)

model.train()for epoch in range(num_epoch):total_loss = 0for i, batch in tqdm(enumerate(train_data_loader), desc=f"Training Epoch {epoch}"):inputs, lengths, targets = [x.to(device) for x in batch]lengths = lengths.cpu()log_probs = model(inputs, lengths)loss = nll_loss(log_probs, targets)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()# 对每一个batch记录loss变化summary.add_scalar('train_batch_loss', total_loss, epoch * len(train_data_loader) + i)scheduler.step()print(f"Loss: {total_loss:.2f}, lr: {scheduler.get_last_lr()[0]}")# 对每一个epoch记录loss变化summary.add_scalar('train_epoch_loss', total_loss, epoch)# 对每一个epoch记录lr变化summary.add_scalar('train_epoch_lr', scheduler.get_last_lr()[0], epoch)

# 测试过程acc = 0for batch in tqdm(test_data_loader, desc=f"Testing"):inputs, lengths, targets = [x.to(device) for x in batch]lengths = lengths.cpu()

with torch.no_grad():output = model(inputs, lengths)acc += (output.argmax(dim=1) == targets).sum().item()

# 输出在测试集上的准确率print(f"Acc: {acc / len(test_data_loader):.2f}")

torch使用tensorboard简明备忘录相关推荐

  1. 错误处理笔记 导入 torch.utils.tensorboard时 找不到tensorboard

    from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter('torch_tensorboard_example') ...

  2. 成功解决ModuleNotFoundError: No module named 'torch.utils.tensorboard'

    成功解决ModuleNotFoundError: No module named 'torch.utils.tensorboard' 目录 解决问题 解决思路 解决方法 解决问题 ModuleNotF ...

  3. torch.utils.tensorboard用法

    相比于visdom,tensorborad对结果可视化集成的更好,界面相对更为美观,内容更为丰富,实现过程也更为简单. 1. Tensorboard面板介绍: 2. 使用tensorboard记录结果 ...

  4. 随便聊聊torch.utils.tensorboard跟tensorboardX(待验证)

    浅谈torch.utils.tensorboard跟tensorboardX 1. 前言 2. 分析 2.1 tensorboardX 2.2 torch.utils.tensorboard 3. 结 ...

  5. torch复现论文简明笔记

    1)常数初始化: x = F.normalize(x, p=2, dim=1)按行计算 x = F.normalize(x, p=2, dim=0)按列计算 torch.empty(size)返回形状 ...

  6. pytorch中ModuleNotFoundError: No module named ‘tensorboard‘

    from torch.utils.tensorboard import SummaryWriter 报错: ModuleNotFoundError: No module named 'tensorbo ...

  7. yolov5的3.0版本代码在训练的时候报错:ImportError: cannot import name ‘amp‘ from ‘torch.cuda‘ 以及yolov5的3.0环境安装

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 yolov5的3.0版本代码在训练的时候报错:ImportError: cannot import name 'amp' ...

  8. Pytorch学习-tensorboard的使用

    Pytorch学习-tensorboard的使用 1 Tensorboard简介 运行机制 安装及测试 2 SummaryWriter实例的使用教程 (1)初始化summaryWriter的方法 (2 ...

  9. pytorch中tensorboard使用

    参考官网 导入tensorboard from torch.utils.tensorboard import SummaryWriter 创建SummaryWriter实例 writer=Summar ...

最新文章

  1. numpy使用[]语法索引二维numpy数组中指定指定行之前所有数据行的数值内容(accessing rows in numpy array before specifc row)
  2. mysql binlog恢复sql_binlog2sql实现MySQL误操作的恢复
  3. jsTree设置默认节点全部展开的方法
  4. JQuery方式执行ajax请求
  5. php 7.2 安装 mcrypt 扩展(亲测)
  6. 三层交换机工作原理介绍
  7. leetcode(一)刷题两数之和
  8. c# winform如何异步不卡界面
  9. linux 将当前时间往后调整2分钟_【转】修改LINUX时间
  10. struts2中的constant配置详解
  11. Atitit sprbt 多数据源mltds datasource multi 目录 第一节 App cfg 1 第二节 Cfg bean 1 第三节 Use 4 第一节 App cfg
  12. php data取年月,PHP-date函数 年、月、日参数详解
  13. php的redis安装配置,Redis 的安装配置介绍_php
  14. 极度未知HyperX20周年盛惠—HyperX Cloud 2 飓风FPS耳机听音辨位
  15. Android 源码目录结构
  16. 用C 绘制实时曲线图
  17. 从“账房先生”到“中国巨型计算机之父”,慈云桂先后主导了中国四代计算机的研发...
  18. Android Q中外部存储盘路径正则表达式的理解
  19. 内网如何下载docker镜像
  20. 线上CPU负载过高处理

热门文章

  1. 反知识蒸馏后门攻击:Anti-Distillation Backdoor Attacks: Backdoors Can Really Survive in Knowledge Distillation
  2. GPS天线走线类型及注意事项
  3. linux 麦克风设备,Linux-创建虚拟麦克风和扬声器
  4. 【Unity】 HTFramework框架(十九)ILHotfix热更新模块
  5. Django-admin后台LOGO字样修改方法
  6. 搞了一个论坛玩玩!http://lupeiqing.3322.org/bbs
  7. U盘不识别,磁盘管理器显示无媒体
  8. AS打开照相机拍照保存本地、显示页面
  9. 微信小程序答题赢红包 微信答题小程序抢红包,答题领微信零钱红包,答题红包小程序,可以自己出题考试的小程序
  10. web前端|品优购|html+css|代码