
Training a Model with Class-Balanced Loss in PyTorch

Paper: https://arxiv.org/pdf/1901.05555.pdf
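The core idea of the paper is to reweight each class by its "effective number" of samples, E_n = (1 - β^n)/(1 - β), so that rare classes receive larger loss weights. A minimal numeric sketch (the sample counts here are made up for illustration):

```python
import numpy as np

# Effective number of samples per class: E_n = (1 - beta^n) / (1 - beta).
# The class-balanced weight is proportional to 1 / E_n, so classes with
# fewer samples get larger weights. The counts below are illustrative.
samples_per_cls = np.array([1000, 100, 10])
beta = 0.999

effective_num = (1.0 - np.power(beta, samples_per_cls)) / (1.0 - beta)
weights = 1.0 / effective_num
# Normalize so the weights sum to the number of classes, as in the
# reference implementation.
weights = weights / weights.sum() * len(samples_per_cls)

print(weights)  # the rarest class (10 samples) gets the largest weight
```

As β → 1 the weights approach plain inverse-frequency weighting; as β → 0 they become uniform.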

1. Data Preparation

Prepare your dataset with one sub-folder per class under a single root directory, then run the following code to split it into train and test sets:

```python
import os
import random
import shutil

import numpy as np


def split_dataset(dataset, split_fraction=0.9):
    # Skip any existing train/test folders so re-running the split works
    classes = [d for d in os.listdir(dataset)
               if os.path.isdir(os.path.join(dataset, d)) and d not in ('train', 'test')]
    train_data_dir = os.path.join(dataset, 'train')
    test_data_dir = os.path.join(dataset, 'test')
    if os.path.exists(train_data_dir) and os.path.exists(test_data_dir):
        shutil.rmtree(train_data_dir)
        shutil.rmtree(test_data_dir)
    os.makedirs(train_data_dir)
    os.makedirs(test_data_dir)
    for cls in classes:
        dr = os.path.join(dataset, cls)
        cls_samples = [f for f in os.listdir(dr) if os.path.isfile(os.path.join(dr, f))]
        train_samples = random.sample(cls_samples, int(np.ceil(len(cls_samples) * split_fraction)))
        test_samples = [s for s in cls_samples if s not in train_samples]
        os.mkdir(os.path.join(dataset, 'train', cls))
        os.mkdir(os.path.join(dataset, 'test', cls))
        for s in train_samples:
            shutil.copy(os.path.join(dr, s), os.path.join(dataset, 'train', cls))
        for s in test_samples:
            shutil.copy(os.path.join(dr, s), os.path.join(dataset, 'test', cls))
    return train_data_dir, test_data_dir
```

After the split, generate the index files: run the snippet below once on the train folder and once on the test folder to produce train.txt and test.txt.

```python
import glob
import os

origin_dir = '/mnt/cecilia/all/train'
cls_dir_list = sorted(os.listdir(origin_dir))
cls_list = []

# Write one "path,label" line per image; point origin_dir at the test
# folder (and change the output name) to generate test.txt the same way.
with open('/mnt/cecilia/train.txt', 'w') as f:
    for i, cls in enumerate(cls_dir_list):
        cls_list.append(cls)
        img_list = glob.glob(os.path.join(origin_dir, cls, '*'))
        print('cls', cls)
        for img in img_list:
            f.write(f'{img},{i}\n')

# Save the index-to-name mapping separately so it is not mixed into the
# image index file.
with open('/mnt/cecilia/class.txt', 'w') as fc:
    fc.write(str(cls_list))
```

You also need a text file holding the number of images in each class, written as a Python list such as [10, 20, …].
The class indices used at training time must line up with the order of this per-class count list.
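The per-class count file can be produced in a few lines like the following; the sorted listing matters so the positions match the labels written into train.txt. The demo runs on a throwaway directory; in practice point `origin_dir` at '/mnt/cecilia/all/train' and write to '/mnt/cecilia/num.txt' (paths assumed from the snippets above):

```python
import glob
import os
import tempfile


def write_class_counts(origin_dir, out_txt):
    """Count images per class folder (sorted, to match the label indices
    used when the index files were generated) and save as "[10, 20, ...]"."""
    cls_num = [len(glob.glob(os.path.join(origin_dir, cls, '*')))
               for cls in sorted(os.listdir(origin_dir))]
    with open(out_txt, 'w') as f:
        f.write(str(cls_num))
    return cls_num


# Demo on a temporary directory with two classes of 3 and 1 images
root = tempfile.mkdtemp()
for cls, n in [('cat', 3), ('dog', 1)]:
    os.makedirs(os.path.join(root, cls))
    for k in range(n):
        open(os.path.join(root, cls, f'{k}.jpg'), 'w').close()

counts = write_class_counts(root, os.path.join(root, 'num.txt'))
print(counts)  # [3, 1]
```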

2. Model Training

```python
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import models

NO_OF_CLASSES = 2388


def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
        labels: A float tensor of size [batch, num_classes].
        logits: A float tensor of size [batch, num_classes].
        alpha: A float tensor of size [batch_size] specifying per-example
            weight for balanced cross entropy.
        gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
        focal_loss: A float32 scalar representing normalized total loss.
    """
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits
                              - gamma * torch.log(1 + torch.exp(-1.0 * logits)))

    loss = modulator * BCLoss
    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)
    focal_loss /= torch.sum(labels)
    return focal_loss


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n)) * Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
        labels: An int tensor of size [batch].
        logits: A float tensor of size [batch, no_of_classes].
        samples_per_cls: A python list of size [no_of_classes].
        no_of_classes: total number of classes. int
        loss_type: string. One of "sigmoid", "focal", "softmax".
        beta: float. Hyperparameter for Class balanced loss.
        gamma: float. Hyperparameter for Focal loss.

    Returns:
        cb_loss: A float tensor representing class balanced loss
    """
    # Class-balanced weights from the effective number of samples
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes
    # weights = np.array([1]) * no_of_classes  # uniform weights, useful for debugging

    labels_one_hot = F.one_hot(labels, no_of_classes).float().to(logits.device)

    # Expand the per-class weights into a per-sample weight matrix
    weights = torch.tensor(weights).float().to(logits.device)
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1, no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    else:
        cb_loss = None
    return cb_loss


def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(Dataset):
    """Reads "path,label" lines from the index file generated above."""

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        with open(txt, 'r') as fh:
            for line in fh:
                line = line.strip('\n').rstrip()
                words = line.split(',')
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


# Use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Command-line arguments
parser = argparse.ArgumentParser(description='PyTorch Class-Balanced Loss Training')
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints')
args = parser.parse_args()

# Hyperparameters
EPOCH = 100       # total number of passes over the dataset
pre_epoch = 0     # number of epochs already completed (for resuming)
BATCH_SIZE = 64
LR = 0.1

# Dataset and preprocessing
transform_train = transforms.Compose([
    transforms.Resize([144, 144]),
    transforms.CenterCrop([128, 128]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per-channel R,G,B mean and std
])
transform_test = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_data = MyDataset(txt='/mnt/cecilia/train.txt', transform=transform_train)
test_data = MyDataset(txt='/mnt/cecilia/test.txt', transform=transform_test)
trainloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)

# Model: pretrained ResNet-34 with a new classification head
net = models.resnet34(pretrained=True).to(device)
net.fc = nn.Sequential(nn.Linear(512, NO_OF_CLASSES)).to(device)

# Optimizer: mini-batch momentum SGD with L2 regularization (weight decay)
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, [50, 80], 0.1)

# Training
if __name__ == "__main__":
    cls_num = eval(open('/mnt/cecilia/num.txt', 'r').read())
    num = np.array(cls_num)
    if not os.path.exists(args.outf):
        os.makedirs(args.outf)
    best_acc = 0  # best test accuracy so far
    print("Start Training, Resnet-34!")
    with open("acc.txt", "w") as f:
        with open("log.txt", "w") as f2:
            for epoch in range(pre_epoch, EPOCH):
                print('\nEpoch: %d' % (epoch + 1))
                net.train()
                sum_loss = 0.0
                correct = 0.0
                total = 0.0
                length = len(trainloader)
                for i, data in enumerate(trainloader, 0):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # forward + backward
                    outputs = net(inputs)
                    loss = CB_loss(labels=labels, logits=outputs,
                                   samples_per_cls=num, no_of_classes=NO_OF_CLASSES,
                                   loss_type="sigmoid", beta=0.9999, gamma=1)
                    loss.backward()
                    optimizer.step()

                    # Log loss and accuracy after every batch
                    sum_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += predicted.eq(labels.data).cpu().sum()
                    print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% | Lr: %.03f'
                          % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1),
                             100. * float(correct) / total,
                             optimizer.state_dict()['param_groups'][0]['lr']))
                    f2.write('%03d  %05d |Loss: %.03f | Acc: %.3f%% | Lr: %.03f'
                             % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1),
                                100. * float(correct) / total,
                                optimizer.state_dict()['param_groups'][0]['lr']))
                    f2.write('\n')
                    f2.flush()

                # Step the LR schedule once per epoch, after the optimizer updates
                scheduler.step()

                # Evaluate on the test set after every epoch
                print("Waiting Test!")
                net.eval()
                with torch.no_grad():
                    correct = 0
                    total = 0
                    for data in testloader:
                        images, labels = data
                        images, labels = images.to(device), labels.to(device)
                        outputs = net(images)
                        # Take the highest-scoring class as the prediction
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum()
                print('Test accuracy: %.3f%%' % (100 * float(correct) / total))
                acc = 100. * float(correct) / total

                # Save a checkpoint and append the accuracy to acc.txt
                print('Saving model......')
                torch.save(net.state_dict(), '%s/net_%03d.pth' % (args.outf, epoch + 1))
                f.write("EPOCH=%03d,Accuracy= %.3f%%| Lr: %.03f"
                        % (epoch + 1, acc, optimizer.state_dict()['param_groups'][0]['lr']))
                f.write('\n')
                f.flush()

                # Track the best test accuracy in best_acc.txt
                if acc > best_acc:
                    with open("best_acc.txt", "w") as f3:
                        f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
                    best_acc = acc
            print("Training Finished, TotalEPOCH=%d" % EPOCH)
```

Run this script to train the model.
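The script saves a checkpoint (`net_%03d.pth`) every epoch and exposes `pre_epoch` for resuming; the save/load cycle it relies on can be sketched with a tiny stand-in model (the `nn.Linear` below replaces `resnet34` only so the sketch runs without a GPU or dataset):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Save a model's state_dict, then restore it into a freshly initialized
# model of the same architecture -- exactly what resuming training does.
net = nn.Linear(4, 2)
ckpt = os.path.join(tempfile.mkdtemp(), 'net_001.pth')
torch.save(net.state_dict(), ckpt)

resumed = nn.Linear(4, 2)          # freshly initialized, different weights
resumed.load_state_dict(torch.load(ckpt))

print(torch.equal(net.weight, resumed.weight))  # True
```

To continue an interrupted run, you would load the last checkpoint into `net` before the training loop and set `pre_epoch` to the epoch it was saved at.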

3. Model Inference

```python
import glob
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Rebuild the training architecture, then load the trained weights
net = models.resnet34(pretrained=True).to(device)
net.fc = nn.Sequential(nn.Linear(512, 2388)).to(device)
PATH = '/data/net_091.pth'
net.load_state_dict(torch.load(PATH))
net.eval()

test_dir = '/data/all/test'
cls_names = sorted(os.listdir(test_dir))
image_cls_acc = {}  # class name -> [image count, top-1 hits, top-5 hits]
num = 0
with torch.no_grad():
    for i, cls in enumerate(cls_names):
        img_list = glob.glob(os.path.join(test_dir, cls, '*'))
        for img_path in img_list:
            img = Image.open(img_path).convert('RGB')
            img_ = transform(img).unsqueeze(0).to(device)
            outputs = net(img_).cpu().numpy()[0]
            top1_index = np.argmax(outputs)
            top5_list = list(outputs.argsort()[-5:])
            print('outputs', i, top1_index)
            if cls not in image_cls_acc:
                image_cls_acc[cls] = [1, 0, 0]
            else:
                image_cls_acc[cls][0] += 1
            if int(i) == int(top1_index):
                image_cls_acc[cls][1] += 1
            if int(i) in top5_list:
                image_cls_acc[cls][2] += 1
            num += 1
```
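The `image_cls_acc` dict built above maps each class name to `[image count, top-1 hits, top-5 hits]`, so overall and per-class accuracy can be summarized from it like this (the sample values below are made up):

```python
# image_cls_acc maps class name -> [image count, top-1 hits, top-5 hits],
# in the shape produced by the evaluation loop. Sample values for illustration.
image_cls_acc = {
    'class_a': [10, 8, 9],
    'class_b': [4, 1, 3],
}

total = sum(v[0] for v in image_cls_acc.values())
top1 = sum(v[1] for v in image_cls_acc.values())
top5 = sum(v[2] for v in image_cls_acc.values())
print('top-1: %.3f%%  top-5: %.3f%%' % (100 * top1 / total, 100 * top5 / total))

# Per-class top-1 accuracy, worst first, to spot problem classes
worst = sorted(image_cls_acc.items(), key=lambda kv: kv[1][1] / kv[1][0])
print(worst[0][0])  # 'class_b'
```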

Summary

Reposted here for now; I will revisit and optimize this code when I have time.
