最近学习代码时发现当自己去实现代码的时候对于样本的Loss和accuracy的计算很不理解,看别人的代码也是靠猜测,所以自己去官方文档学习加上自己做了个小实验以及搜索了别人的博客,总是算明白了怎么回事,所以打算写下来记录(纯粹记录,无参考意义)

accuracy 计算

关于accuracy的计算:acc=正确个数 / 样本总数我们知道,经过模型的输出的最后的一个结果是通过一个softmax算法的出来的,也就是说,输出的给过给出了这个模型对于每个类别的概率预测(且所有概率相加等于1),概率最大的类别也就是模型预测出的类别(插一句:那么我们可能遇到这样的情况,有几个类别预测的概率相差不大,最终结果是那一个比其他类大一点的结果,那其实这个时候表示模型的泛化性能很不好,所以对于网络的评价Loss是最好用的);那么首先我们要做的是要将预测结果最大概率的标签(类别)提取出来,这里我给出一个例子:

a = torch.tensor([[0.03,0.12,0.85], [0.01,0.9,0.09], [0.95,0.01,0.04], [0.09, 0.9, 0.01]])
print(a)
print(a.dtype)

我假设有一个经过模型输出的结果,现在我们需要提出最大概率的结果

我们用torch.max()函数打印了结果,可以看到,输出的是最大类别的概率以及对应索引(即类别),但是我们只要索引,所以取第二个元素predicted=torch.max(output.data, 1)[1],解释代码:这里第一个1表示求行的最大值因为每一行代表了一个样本的输出结果,第二个1代表了我们只要索引,还有另一种方法:predicted=torch.argmax(output, 1),torch.argmax()函数:求最大数值的索引
接下来我们需要计算预测正确的个数:correct += (output == labels).sum().item(),首先,“outpuut == labels” 的语法求出正确的类,类似于[1, 0, 1, 0],1表示预测正确,0表示错误,然后 .sum()将所有正确的预测加起来,得到预测正确的个数,torch.item(),这时候输出的是一个tensor类,比如有两个预测正确:tensor(2),.item() 语法将tensor转化为普通的float或者int,

最后total += label.size(0),求出样本总数:那么acc = correct / total,得出精度

Loss的计算

现在的框架大多是采用的是Minibatch梯度下降法,一般在进行Loss的计算我们都是用的交叉熵损失函数计算损失比如 CrossEntropyLoss(),其实,这个函数求出的是每次minibatch的平均损失,那么当我们将每次Minibatch损失加起来后需要除以的是step(即步长),数据的分批次计算采用的是torch.utils.data.DataLoader()函数,那么就需要步长就是len(Loadr),用最终的总Loss除以len(Loader),这就是Loss计算
下面会给出个人的代码全部(口说无凭,且不太好理解)

from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchvision
import torch.nn as nn
import torch
from tqdm import tqdm
import sys
def train(lr, weight_decay, num_epochs):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.\n".format(device))train_transfrom = transforms.Compose([transforms.Resize(224),transforms.RandomResizedCrop(224, scale=(0.64, 1.0), ratio=(0.75, 1.33)),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=(0.7, 1.3), contrast=(0.8, 1.2), saturation=(0.9, 1.1), hue=0),transforms.RandomRotation(degrees=(-20, 20)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])valid_transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = ImageFolder('../input/fruits/fruits-360_dataset/fruits-360/Training', transform=train_transfrom)valid_dataset = ImageFolder('../input/fruits/fruits-360_dataset/fruits-360/Test', transform=valid_transform)n_classes = len(train_dataset.classes)train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=32,shuffle=True,num_workers=2)valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,batch_size=16,shuffle=False,num_workers=2)loss_function = nn.CrossEntropyLoss()net = torchvision.models.resnet18(pretrained=True)in_channel = net.fc.in_featuresnet.fc = nn.Sequential(nn.Linear(in_channel, n_classes))net.to(device)optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=weight_decay)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.6, verbose=True)best_acc = 0for epoch in range(num_epochs):net.train()train_total = 0train_correct = 0train_loss = 0for batch in tqdm(train_loader):imgs, labels = batchimgs, labels = imgs.to(device), labels.to(device)optimizer.zero_grad()optputs = net(imgs)loss = loss_function(optputs, labels)loss.backward()optimizer.step()train_loss += loss.item()predicted = torch.argmax(optputs, 1)train_correct += (predicted == labels).sum().item()train_total += labels.size(0)del imgs, labelstorch.cuda.empty_cache()scheduler.step()train_loss = train_loss / len(train_loader)train_accuracy = train_correct / train_totalprint(f"[ Train | {epoch + 1:03d}/{num_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_accuracy:.5f}")net.eval()valid_correct, valid_total, valid_loss = 0, 0, 0for batch in tqdm(valid_loader):imgs, labels = batchimgs, labels = imgs.to(device), labels.to(device)with torch.no_grad():outputs = net(imgs)loss = loss_function(outputs, labels)#predicted = torch.argmax(optputs, 1)predicted = torch.max(outputs.data, 1)[1]valid_correct += (predicted == labels).sum().item()valid_loss += loss.item()valid_total += labels.size(0)del imgs, labelstorch.cuda.empty_cache()valid_accuracy = valid_correct / valid_totalvalid_loss = valid_loss / len(valid_loader)print(f"[ Valid | {epoch + 1:03d}/{num_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_accuracy:.5f}\n")if valid_accuracy > best_acc:best_acc =valid_accuracyprint(f"best acc [{valid_accuracy:.5f}] in epoch {epoch + 1}\n")print(f"last, best acc [{valid_accuracy:.5f}]")

pytorch accuracy和Loss 的计算相关推荐

  1. Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...

  2. pytorch中网络loss传播和参数更新理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...

  3. pytorch训练Class-Balanced Loss

    1. 提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1. pytorch版Class-Balanced Loss训练模型 一.数据准备 二.模型训练 三.模型预测 总结 ...

  4. pytorch版Class-Balanced Loss训练模型

    pytorch版Class-Balanced Loss训练模型 1.论文参考原文 https://arxiv.org/pdf/1901.05555.pdf 2.数据准备 将自己的数据集按照一下格式进行 ...

  5. (27)第四节课:从零起步在框架中编码实现损失度Loss的计算及其可视化

    本节将从零起步在框架中编码实现损失度Loss的计算及其可视化.本节包含三部分的内容: (1)   关于Loss的思考,Loss是所有AI框架终身的魔咒. (2)   编码实现Loss实现并进行测试. ...

  6. Pytorch里面多任务Loss是加起来还是分别backward?

    作者丨歪杠小胀@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/451441329 编辑丨极市平台 导读 如果只有一个loss,那么直接loss.backward()即 ...

  7. 有bug!PyTorch在AMD CPU的计算机上卡死了

    视学算法报道 转载自:机器之心 编辑:小舟.陈萍 AMD,No?PyTorch在AMD CPU的机器上出现死锁了. PyTorch 作为机器学习中广泛使用的开源框架,具有速度快.效率高等特点.而近年来 ...

  8. 交叉熵损失(Cross Entropy Loss)计算过程

    交叉熵损失(Cross Entropy Loss)计算过程_藏知阁-CSDN博客_交叉熵计算公式

  9. Pytorch中nn.Conv2d数据计算模拟

    Pytorch中nn.Conv2d数据计算模拟 最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块.在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin ...

  10. 如何利用PyTorch中的Moco-V2减少计算约束

    介绍 SimCLR论文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解释了这个框架如何从更大的模型和更大的批处理中获益,并且如果有足够的计算 ...

最新文章

  1. 2014Esri国际用户大会ArcGIS Online
  2. 【数据平台】sklearn库特征工程之数据预处理
  3. PHP创建图像的应用!!!!
  4. Eclipse中单元测试
  5. 【OS学习笔记】四十 保护模式十:中断和异常的处理与抢占式多任务对应的汇编代码----动态加载的用户程序/任务二代码
  6. three.js 渲染器更改背景色的几种方法
  7. 还不懂!软件测试(功能、接口、性能、自动化)详解
  8. python面向对象思路_Python面向对象三要素-继承(Inheritance)
  9. uva 11916 Emoogle Grid (BSGS)
  10. 提取swf素材_swf素材提取工具
  11. 关于vs2015各版本的卸载
  12. 拼音模糊搜索 php,基于 XunSearch(迅搜)SDK 的全文搜索 Laravel 5.* 软件包,支持全拼、拼音简写、模糊搜索、热门搜索、搜索提示...
  13. 自媒体月入过万的自媒体赚钱工具,免费教会你!
  14. Excel如何合并相同项单元格
  15. java基础——多态
  16. 2019最新天善智能python3数据分析与挖掘项目实战(完整)
  17. 2019年成功与失败的危机公关案例分析
  18. 指南-Luat二次开发教程-功能开发教程-SOCKET
  19. 高速串行计算机扩展总线标准,高速串行计算机扩展总线标准Bosch Sensortec开发出BMP384...
  20. 【软考笔记】1. 计算机原理与体系结构

热门文章

  1. MeanShift算法原理及其python自定义实现
  2. 如何批量导出QQ空间相册到电脑中
  3. PR中视频材料声音大小不一样?1招快速统一音量
  4. PS4 eye camera v2 ROS测试
  5. 经济危机离你并不遥远!
  6. 2007年中国网络游戏市场分析及投资咨询报告(上下卷)
  7. 电脑桌面计算机图标下不显示文字,电脑桌面图标下面的文字有时会突然不见,然后 – 手机爱问...
  8. 电脑两个,电脑有两个系统盘怎么办
  9. 通过Redis实现数据的交集、并集、补集
  10. excel怎么合并同类项数据并求和(去除重复项)