点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:作者:Marek Paulik

编译:ronghuaiyang

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

导读

一个非常简单和容易上手的例子。

对于教程中使用的大多数人工数据集,每个类都有相同数量的数据。然而,在实际应用中,这种情况很少发生。今天,我将给你介绍来自Kaggle的木薯叶分类,并告诉你当类频率有很大差异时该怎么做。

处理类别的不平衡

有两种方法可以解决这个问题。

  • WeightedRandomSampler

  • loss函数中的weight参数

下一步是创建一个有5个方法的CassavaClassifier类:load_data()、load_model()、fit_one_epoch()、val_one_epoch()和fit()。

在load_data()中,将构造一个train和验证数据集,并返回数据加载器以供进一步使用。

在load_model()中定义了体系结构、损失函数和优化器。

fit方法包含一些初始化和对fit_one_epoch()和val_one_epoch()的循环。

早期停止

早期停止类有助于根据验证损失跟踪最佳模型,并保存检查点。

#Callbacks
# Early stopping
class EarlyStopping:def __init__(self, patience=1, delta=0, path='checkpoint.pt'):self.patience = patienceself.delta = deltaself.path= pathself.counter = 0self.best_score = Noneself.early_stop = Falsedef __call__(self, val_loss, model):if self.best_score is None:self.best_score = val_lossself.save_checkpoint(model)elif val_loss > self.best_score:self.counter +=1if self.counter >= self.patience:self.early_stop = True else:self.best_score = val_lossself.save_checkpoint(model)self.counter = 0      def save_checkpoint(self, model):torch.save(model.state_dict(), self.path)

Init

我们首先初始化CassavaClassifier类。

class CassavaClassifier():def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,lr=1e-4, stop_early=True, freeze_backbone=True):############################################################################################################## data_dir - directory with images in subfolders, subfolders name are categories# Transform - data augmentations# sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used# loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function# freeze_backbone - if using pretrained architecture freeze all but the classification layer###############################################################################################################self.data_dir = data_dirself.num_classes = num_classesself.device = deviceself.sample = sampleself.loss_weights = loss_weightsself.batch_size = batch_sizeself.lr = lrself.stop_early = stop_earlyself.freeze_backbone = freeze_backboneself.Transform = Transform

Load Data

训练图像被组织在子文件夹中,子文件夹名称表示图像的类。这是图像分类问题的典型情况,幸运的是,不需要编写自定义数据集类。在这种情况下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要为数据集的每个元素指定一个权重。通常,总图像总比上类别数被用作一个权重。

def load_data(self):train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])self.train_classes = [label for _, label in train_set]if self.sample:# Need to get weight for every image in the datasetclass_count = Counter(self.train_classes)class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values]) # Can't iterate over class_count because dictionary is unorderedsample_weights = [0] * len(train_set)for idx, (image, label) in enumerate(train_set):class_weight = class_weights[label]sample_weights[idx] = class_weightsampler = WeightedRandomSampler(weights=sample_weights,num_samples = len(train_set), replacement=True)  train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)else:train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)val_loader = DataLoader(val_set, batch_size=self.batch_size)return train_loader, val_loader

Load Model

在该方法中,我使用迁移学习,架构参数从预先训练的resnet50和efficientnet-b7中选择。CrossEntropyLoss和许多其他损失函数都有权重参数。这是一个手动调整参数,用于处理不平衡。在这种情况下,不需要为每个参数定义权重,只需为每个类定义权重。

def load_model(self, arch='resnet'):############################################################################################################### arch - choose the pretrained architecture from resnet or efficientnetb7############################################################################################################## if arch == 'resnet':self.model = torchvision.models.resnet50(pretrained=True)if self.freeze_backbone:for param in self.model.parameters():param.requires_grad = Falseself.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)elif arch == 'efficient-net':self.model = EfficientNet.from_pretrained('efficientnet-b7')if self.freeze_backbone:for param in self.model.parameters():param.requires_grad = Falseself.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)    self.model = self.model.to(self.device)self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr) if self.loss_weights:class_count = Counter(self.train_classes)class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])# Cant iterate over class_count because dictionary is unorderedclass_weights = class_weights.to(self.device)  self.criterion = nn.CrossEntropyLoss(class_weights)else:self.criterion = nn.CrossEntropyLoss()

Fit One Epoch

这个方法只包含一个经典的训练循环,带有训练损失记录和tqdm进度条。

def fit_one_epoch(self, train_loader, epoch, num_epochs ): step_train = 0train_losses = list() # Every epoch check average loss per batch train_acc = list()self.model.train()for i, (images, targets) in enumerate(tqdm(train_loader)):images = images.to(self.device)targets = targets.to(self.device)logits = self.model(images)loss = self.criterion(logits, targets)loss.backward()self.optimizer.step()self.optimizer.zero_grad()train_losses.append(loss.item())#Calculate running train accuracypredictions = torch.argmax(logits, dim=1)num_correct = sum(predictions.eq(targets))running_train_acc = float(num_correct) / float(images.shape[0])train_acc.append(running_train_acc)train_loss = torch.tensor(train_losses).mean()    print(f'Epoch {epoch}/{num_epochs-1}')  print(f'Training loss: {train_loss:.2f}')

Validate one epoch

与上面类似,但此方法在验证数据加载器上迭代。在每一个epoch'之后,平均batch损失和准确性被打印出来。

def val_one_epoch(self, val_loader, scaler):val_losses = list()val_accs = list()self.model.eval()step_val = 0with torch.no_grad():for (images, targets) in val_loader:images = images.to(self.device)targets = targets.to(self.device)logits = self.model(images)loss = self.criterion(logits, targets)val_losses.append(loss.item())      predictions = torch.argmax(logits, dim=1)num_correct = sum(predictions.eq(targets))running_val_acc = float(num_correct) / float(images.shape[0])val_accs.append(running_val_acc)self.val_loss = torch.tensor(val_losses).mean()val_acc = torch.tensor(val_accs).mean() # Average acc per batchprint(f'Validation loss: {self.val_loss:.2f}')  print(f'Validation accuracy: {val_acc:.2f}')

Fit

Fit方法在训练和验证过程中经历了许多阶段和循环。如果预训练模型的参数在开始时被冻结,那么unfreeze_after定义了整个模型在多少个epoch之后开始训练。在此之前,只训练全连接层(分类器)。

def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):if self.stop_early:early_stopping = EarlyStopping(patience=5, path=checkpoint_dir)for epoch in range(num_epochs):if self.freeze_backbone:if epoch == unfreeze_after:  # Unfreeze after x epochsfor param in self.model.parameters():param.requires_grad = Trueself.fit_one_epoch(train_loader, scaler, epoch, num_epochs)self.val_one_epoch(val_loader, scaler)if self.stop_early:early_stopping(self.val_loss, self.model)if early_stopping.early_stop:print('Early Stopping')print(f'Best validation loss: {early_stopping.best_score}')break

Run

现在,可以初始化CassavaClassifier类、创建dataloaders、设置模型并运行整个过程了。

Transform = T.Compose([T.ToTensor(),T.Resize((256, 256)),T.RandomRotation(90),T.RandomHorizontalFlip(p=0.5),T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
data_dir = "Data/cassava-disease/train/train"classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform)
train_loader, val_loader = classifier.load_data()
classifier.load_model()
classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader)

Inference

使用ImageFolder加载测试数据是不可能的,因为显然没有带有类的子文件夹。因此,我创建了一个返回图像和图像id的自定义数据集。随后,加载模型检查点,通过推理循环运行它,并将预测保存到数据帧中。将数据帧导出为CSV并提交结果。

# Inference
model = torchvision.models.resnet50()
#model = EfficientNet.from_name('efficientnet-b7')
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5)
model = model.to(device)
checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt')
model.load_state_dict(checkpoint)
model.eval()# Dataset for test data
class Cassava_Test(Dataset):def __init__(self, dir, transform=None):self.dir = dirself.transform = transformself.images = os.listdir(self.dir)  def __len__(self):return len(self.images)def __getitem__(self, idx):img = Image.open(os.path.join(self.dir, self.images[idx]))return self.transform(img), self.images[idx] test_dir = 'Data/cassava-disease/test/test/0'
test_set = Cassava_Test(test_dir, transform=Transform)
test_loader = DataLoader(test_set, batch_size=4)  # Test loop
sub = pd.DataFrame(columns=['category', 'id'])
id_list = []
pred_list = []model = model.to(device)with torch.no_grad():for (image, image_id) in test_loader:image = image.to(device)logits = model(image)predicted = list(torch.argmax(logits, 1).cpu().numpy())for id in image_id:id_list.append(id)for prediction in predicted:pred_list.append(prediction)
sub['category'] = pred_list
sub['id'] = id_listmapping = {0:'cbb', 1:'cbsd', 2:'cgm', 3:'cmd', 4:'healthy'}sub['category'] = sub['category'].map(mapping)
sub = sub.sort_values(by='id')sub.to_csv('Cassava_sub.csv', index=False)

如果在方案中包含WeightedRandomSampler或损失权值,则测试集的精度会提高2%。对于仅仅几行代码来说,这是一个很好的改进。对于这个数据集,我没有看到这两种方法在精度上的巨大差异,但WeightedRandomSampler的表现要好一些。

不同的学习速度、优化器和数据扩展肯定有自己的发展空间。然而,对于这种简单的方法来说,86%的准确率似乎足够好了。

—END—

英文原文:https://marekpaulik.medium.com/imbalanced-dataset-image-classification-with-pytorch-6de864982eb1

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

使用PyTorch来进展不平衡数据集的图像分类相关推荐

  1. (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类

    pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...

  2. 使用Pytorch框架自己制作做数据集进行图像分类(一)

    第一章:Pytorch制作自己的数据集实现图像分类 第一章: Pytorch框架制作自己的数据集实现图像分类 第二章: Pytorch框架构建残差神经网络(ResNet) 第三章: Pytorch框架 ...

  3. Pytorch打怪路(三)Pytorch创建自己的数据集2

    前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...

  4. Pytorch:利用迁移学习做图像分类

    **Pytorch:利用迁移学习做图像分类** 数据准备 数据扩充 数据加载 迁移学习 训练 验证 推理/分类 在这一篇文章中,我们描述了如何在 pytorch中进行图像分类.我们将使用Caltech ...

  5. Pytorch 目标检测和数据集

    Pytorch 目标检测和数据集 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 S ...

  6. Paddle 环境中 使用LeNet在MNIST数据集实现图像分类

    简 介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例.基于可以搭建其他网络程序. 关键词: MNIST,Paddle,LeNet #mermaid-svg-FlRI ...

  7. coco数据集大小分类_如何处理不平衡数据集的分类任务

    在情感分类任务中,数据集的标签分布往往是极度不平衡的.以我目前手上的这个二分类任务来说,正例样本14.4万个:负例样本166.1万 = 1 :11.5.很显然这是一个极度不平衡的数据集,假设我把样本全 ...

  8. TypeError: 'module' object is not callable (pytorch在进行MNIST数据集预览时出现的错误)

    在使用pytorch在对MNIST数据集进行预览时,出现了TypeError: 'module' object is not callable的错误: 上报错信息图如下: 从图中可以看出,报错位置为第 ...

  9. 〖TensorFlow2.0笔记21〗自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where!

    自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where! 文章目录 一. 数据集介绍以及加载 1.1. 数据集简单描述 1.2. 程序实现步骤 1.3. 加载数据的格式 1.4. map函 ...

最新文章

  1. android checkbox状态不刷新,android开发分享更改checkbox的值,而不触发onCheckChanged
  2. 自动计算表格html,表格怎么自动计算加减
  3. Laravel Passport里的授权类型介绍
  4. CF1550E Stringforces
  5. 一直显示数据格式错误_Excel数据分析,新手最容易犯的10个建表错误
  6. LeetCode_database刷题记录(620. 有趣的电影)
  7. 地铁系统_北斗授时助力北京地铁地下定位系统
  8. Tensorflow实现VGG网络
  9. ios开发--清理缓存
  10. 连麦互动技术及其连麦调研
  11. 【细胞分割】基于matlab中值滤波+分水岭法细胞计数【含Matlab源码 640期】
  12. 惠普打印机故障代码_惠普打印机故障代码
  13. java 图片 大小_在JAVA中调整图片大小
  14. Matlab/Simulink 自动代码生成 基于模型设计学习教程(2)---- 闪烁灯实验
  15. macOS上的符号链接Symlink是什么,以及该怎么使用
  16. 我在国图读完的第二本书 —— 《经济学的思维方式》
  17. linux 查询ip归属地的工具,Linux 通过shell查询ip归属地(curl请求转码)
  18. 亚马逊评价计算器 分析评价利器
  19. STM32DAC输出遇到的问题
  20. 提醒电脑族:眼睛酸涩会致病

热门文章

  1. anaconda创建新环境_【创建社会主义新农村】怀城街道:转变整治理念 农村人居环境换新颜...
  2. 设置mysql为utf-8_如何设置mysql数据库为utf-8编码
  3. python 图片读写_Python各种图像库的图像的基本读写方式
  4. 升级无法登录_JeeSite v4.2.2 发布,代码生成增强、Boot 2.3、短信登录、性能提升...
  5. linux4.9下alsa架构,[Alsa]4, wm8524 Kernel音频子系统入口
  6. linux fips 模式,linux – FIPS Capable OpenSSL交叉编译:内容指纹问题
  7. 第三周博客作业西北师范大学|李晓婷
  8. 蓝桥杯 基础练习 高精度加法
  9. Nginx配置性能优化(转)
  10. Sharepoint COMException 0x81020037