早在 2012 年,神经网络就首次赢得了 ImageNet 大规模视觉识别挑战。Alex Krizhevsky,Ilya Sutskever 和 Geoffrey Hinton 彻底改变了图像分类领域。

如今,为图像(或图像分类)分配单个标签的任务已经非常成熟。然而,实际场景并不仅限于“一张图片一个标签”的任务——有时,我们需要更多!

在这篇文章中,我们将看看分类任务的一种修改——所谓的多输出分类或图像标记。

1.什么是多标签分类

在图像分类领域,您可能会遇到需要确定对象的多个属性的场景。例如,这些可以是类别、颜色、尺寸等。与通常的图像分类相比,此任务的输出将包含 2 个或更多属性。

在本教程中,我们将关注一个我们事先知道属性数量的问题。这种任务称为多输出分类。实际上,这是多标签分类的一个特例,您还可以预测多个属性,但它们的数量可能因样本而异。

2.数据集

我们将在 Kaggle 网站上提供的“Fashion Product Images”数据集的低分辨率子集上进行练习:https://www.kaggle.com/

在这篇文章中,我们将使用时尚产品图像数据集。它包含超过 44 000 张衣服和配饰的图像,每张图像有 9 个标签。

文件 fashion-product-images/styles.csv 包含数据标签。为简单起见,我们将在教程中仅使用三个标签:gender、articleType 和 baseColour

让我们看一下数据集中的一些示例:

让我们也从数据注释中提取我们类别的所有唯一标签。总的来说,我们将拥有:

  • 5个性别值(男孩,女孩,男人,中性,女人)
  • 47种颜色
  • 143件物品(如运动凉鞋、钱包或毛衣)。

我们的目标是创建和训练一个神经网络模型来预测我们数据集中图像的三个标签(性别、物品和颜色)。

3.所需的库

  • matplotlib
  • numpy
  • pillow
  • scikit-learn
  • torch
  • torchvision
  • tqdm
    所有的库都可以从 requirements.txt 文件安装:
python3 -m pip install -r requirements.txt

尽管下面的代码与设备无关并且可以在 CPU 上运行,但我建议使用 GPU 来显着减少训练时间。 GPU 是脚本中的默认选项。

4.拆分数据集

我们总共使用 40 000 张图像。我们将其中的 32 000 个放入训练集中,其余 8 000 个用于验证。要拆分数据,请运行 split_data.py 脚本:

all_data = []
# 打开注释文件
with open(annotation) as csv_file:# 将其解析为 CSVreader = csv.DictReader(csv_file)# tqdm 显示进度条# CSV 文件中的每一行都对应于图像for row in tqdm(reader, total=reader.line_num):# 我们需要图像 ID 来构建图像文件的路径img_id = row['id']# 我们将只使用 3 个属性gender = row['gender']articleType = row['articleType']baseColour = row['baseColour']img_name = os.path.join(input_folder, 'images', str(img_id) + '.jpg')# 检查文件是否存在if os.path.exists(img_name):# 检查图像是否有 80*60 像素和 3 个通道img = Image.open(img_name)if img.size == (60, 80) and img.mode == "RGB":all_data.append([img_name, gender, articleType, baseColour])else:print("Something went wrong: there is no file ", img_name)# 设置随机数生成器的种子,以便我们稍后重现结果
np.random.seed(42)
# 从列表中构造一个 Numpy 数组
all_data = np.asarray(all_data)# 随机抽取 40000 个样本
inds = np.random.choice(40000, 40000, replace=False)
# 将数据拆分为 train/val 并将它们保存为 csv 文件
save_csv(all_data[inds][:32000], os.path.join(output_folder, 'train.csv'))
save_csv(all_data[inds][32000:40000], os.path.join(output_folder, 'val.csv'))

上面的代码创建了 train.csv 和 val.csv。在相应的拆分文件中存储图像列表及其标签。

5.加载数据集

由于在数据注释中有多个标签,我们需要调整读取数据并将其加载到内存中的方式。为此,我们将创建一个继承PyTorch Dataset的类。它将能够解析我们的数据注释并只提取我们感兴趣的标签。多输出和多分类之间的关键区别是,我们将从数据集为每个样本返回几个标签。

class FashionDataset(Dataset):def __init__(...):...# 初始化数组以存储真实标签和图像路径self.data = []self.color_labels = []self.gender_labels = []self.article_labels = []# 从 CSV 文件中读取注释with open(annotation_path) as f:reader = csv.DictReader(f)for row in reader:self.data.append(row['image_path'])self.color_labels.append(self.attr.color_name_to_id[row['baseColour']])self.gender_labels.append(self.attr.gender_name_to_id[row['gender']])self.article_labels.append(self.attr.article_name_to_id[row['articleType']])

我们数据集类的 getitem() 函数获取一个图像和三个相应的标签。然后它为训练增强图像,并将其标签作为字典返回:

def __getitem__(self, idx):# 按索引取数据样本img_path = self.data[idx]# 读取图像img = Image.open(img_path)# 如果需要,应用图像增强if self.transform:img = self.transform(img)# 返回图像和所有相关标签dict_data = {'img': img,'labels': {'color_labels': self.color_labels[idx],'gender_labels': self.gender_labels[idx],'article_labels': self.article_labels[idx]}}return dict_data

好的,看来我们已经准备好加载我们的数据了。

6.数据增强

数据增强是保持图像可识别的随机变换。它们随机化数据,从而帮助我们在训练网络时对抗过度拟合。

在这里,我们将使用随机翻转、轻微的颜色修改、旋转、缩放和平移(统一在仿射变换中)。我们还将在将数据加载到网络之前对其进行标准化——这是深度学习中的标准方法。

# 在训练期间指定用于增强的图像变换
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.8, 1.2),shear=None, resample=False, fillcolor=(255, 255, 255)),transforms.ToTensor(),transforms.Normalize(mean, std)
])

在验证阶段,我们不会随机化数据——只是将其标准化并将其转换为 PyTorch Tensor 格式。

# 在验证期间,我们只使用张量和归一化变换
val_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)
])

现在,当我们的数据集准备好时,让我们定义模型。

7.构建模型

查看模型类定义。我们从 torchvision.models 获取 mobilenet_v2 网络。这个模型可以解决ImageNet的分类问题,所以它的最后一层是单分类器。

为了将此模型用于多输出任务,我们将对其进行修改。我们需要预测三种属性,所以我们将使用三个新的分类头,而不是一个单一的分类器:这些头被称为颜色、性别和文章。每个头都有它自己的交叉熵损失。

现在让我们看看我们如何定义网络和这些新头。

class MultiOutputModel(nn.Module):def __init__(self, n_color_classes, n_gender_classes, n_article_classes):super().__init__()self.base_model = models.mobilenet_v2().features  # 取没有分类器的模型last_channel = models.mobilenet_v2().last_channel # 分类器前层的大小# 分类器的输入应该是二维的,但我们将有# [<batch_size>, <channels>, <width>, <height>]# 所以,让我们做空间平均:将 <width> 和 <height> 减少到 1self.pool = nn.AdaptiveAvgPool2D((1, 1))# 为我们的输出创建单独的分类器self.color = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_color_classes))self.gender = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_gender_classes))self.article = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_article_classes))

在通过网络的前向传递中,我们还使用自适应平均池化对最后 2 个张量维度(宽度和高度)进行平均。我们这样做是为了得到一个适合作为我们分类器输入的张量。请注意,我们将每个分类器并行应用于网络输出,并返回一个包含三个结果值的字典:

def forward(self, x):x = self.base_model(x)x = self.pool(x)# 将 [batch, channels, 1, 1] 重塑为 [batch, channels] 以将其放入分类器x = torch.flatten(x, start_dim=1)return {'color': self.color(x),'gender': self.gender(x),'article': self.article(x)}

现在让我们定义多输出网络的损失。事实上,我们将我们的损失定义为三种损失的总和——颜色、性别和物品:

def get_loss(self, net_output, ground_truth):color_loss = F.cross_entropy(net_output['color'], ground_truth['color_labels'])gender_loss = F.cross_entropy(net_output['gender'], ground_truth['gender_labels'])article_loss = F.cross_entropy(net_output['article'], ground_truth['article_labels'])loss = color_loss + gender_loss + article_lossreturn loss, {'color': color_loss, 'gender': gender_loss, 'article': article_loss}

现在我们已经准备好了模型和数据。让我们开始训练。

8.训练

多输出分类情况的训练过程与单输出分类任务的训练过程相同,所以我在这里只提到几个步骤。

首先,我们将定义训练和模型本身的几个参数。

在这里我使用小批量,因为在这种情况下它提供了更好的准确性。您可以尝试不同的值(例如 128 或 256)并自行检查 - 训练时间会减少,但效果可能会受到影响。

N_epochs = 50
batch_size = 16
...model = MultiOutputModel(n_color_classes=attributes.num_colors, n_gender_classes=attributes.num_genders,n_article_classes=attributes.num_articles).to(device)optimizer = torch.optim.Adam(model.parameters())

然后我们在主循环中运行训练:

for epoch in range(start_epoch, N_epochs + 1):...for batch in train_dataloader:optimizer.zero_grad()...

将数据批次馈送到网络:

img = batch['img']
target_labels = batch['labels']
...output = model(img.to(device))
...

计算损失和准确度:

loss_train, losses_train = model.get_loss(output, target_labels)
total_loss += loss_train.item()
batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \calculate_metrics(output, target_labels)
...

最后,我们通过我们的模型反向传播损失并应用由此产生的权重更新:

loss_train.backward()
optimizer.step()
...

我们每 5 个 epoch 对验证数据集运行一次评估,并每 25 个 epoch 保存一次checkpoint:

if epoch % 5 == 0:validate(model, val_dataloader, attributes, logger, epoch, device)if epoch % 25 == 0:checkpoint_save(model, savedir, epoch)
...

9.评估

让我们暂时回到单输出分类任务。该问题的“默认”指标是什么?这是准确性。在最简单的情况下(我们不考虑类不平衡)准确率的定义是我们传递给模型的所有数据中正确预测的计数:

我们的多输出分类任务的指标应该是什么?确实,我们仍然可以使用准确性!回想一下,我们有几个来自网络的独立输出——每个标签一个。我们可以用与单输出分类相同的方式独立计算每个标签的准确度。

首先,我们应该将图像从数据集传递给模型并获得预测。在下面的代码中,我们将为“颜色”类执行此操作,但对于我们用于训练的所有类,该过程是相同的:

# 将模型置于评估模式
model.eval()# 初始化真实值和预测标签
predicted_color_all = []
gt_color_all = []# 浏览所有图片
for batch in dataloader:images = batch["img"]# 我们将为“颜色”预测构建混淆矩阵gt_colors = batch["labels"]["color_labels"]target_labels = {"color": gr_colors.to(device)}# 获取模型输出output = model(images.to(device))# 为每张图像获得最可信的预测_, predicted_colors = output["color"].cpu().max(1)predicted_color_all.extend(prediction.item() for prediction in predicted_colors)gt_color_all.extend(gt_color.item() for gt_color in gt_colors)

接下来,有了所有的预测和标签,我们可以计算准确率。具体来说,我们可以计算模型推理循环中每个批次的准确率,并在批次之间取平均值。由于我们将进一步使用预测和真实值,让我们保留它们并在循环外进行准确度计算:

from sklearn.metrics import accuracy_scoreaccuracy_color = accuracy_score(gt_color_all, predicted_color_all)

如果我们查看指标,我们会发现最终模型对物品类型的准确率约为 80%,对性别的准确率为 82%,对颜色的准确率为 60%。 这些值还可以,但不是很好。我们来看看测试数据集中的图像和预测标签:

大多数预测看起来都很合理,那么到底出了什么问题呢?

10.混淆矩阵

混淆矩阵是用于调试图像分类模型的出色工具。使用它,您可以获得关于您的模型可以很好地识别哪些类以及它混淆了哪些类的宝贵见解。

构建混淆矩阵图,我们首先需要的是模型预测。是的,这就是我们更早保存它们的原因!

由于我们有预测和真实标签,我们已准备好构建混淆矩阵:

from sklearn.metrics import (confusion_matrix,ConfusionMatrixDisplay
)
...cn_matrix = confusion_matrix(y_true=gt_color_all,y_pred=predicted_color_all,labels=attributes.color_labels,normalize="true",
)
ConfusionMatrixDisplay(cn_matrix, attributes.color_labels).plot(include_values=False, xticks_rotation="vertical"
)
plt.title("Colors")
plt.tight_layout()
plt.show()


现在很明显,该模型混淆了类似的颜色,例如洋红色、粉红色和紫色。即使对于人类来说,也很难识别数据集中表示的所有 47 种颜色。

正如我们所见,在我们的案例中,低色彩准确度并不是什么大问题。如果要修复它,可以将数据集中的颜色数量减少到例如 10 个,将相似颜色重新映射到一个类,然后重新训练模型。你应该得到更好的结果。

对于性别,我们看到了类似的行为:

该模型混淆了“女孩”和“女性”标签、“男性”和“中性”。同样,对于人类来说,在这些情况下有时也很难检测到正确的衣服标签。

最后,这里是衣服和配饰的混淆矩阵。请注意,它的主对角线非常明显,即在大多数情况下预测标签与基本事实相符:

同样,有些物品很难区分——下面的这些包就是很好的例子:

11.完整代码

(1)split_data.py

# split_data.py
import argparse
import csv
import osimport numpy as np
from PIL import Image
from tqdm import tqdmdef save_csv(data, path, fieldnames=['image_path', 'gender', 'articleType', 'baseColour']):with open(path, 'w', newline='') as csv_file:writer = csv.DictWriter(csv_file, fieldnames=fieldnames)writer.writeheader()for row in data:writer.writerow(dict(zip(fieldnames, row)))if __name__ == '__main__':parser = argparse.ArgumentParser(description='Split data for the dataset')parser.add_argument('--input', type=str, default="./fashion-product-images", help="Path to the dataset")parser.add_argument('--output', type=str, default="", help="Path to the working folder")args = parser.parse_args()input_folder = args.inputoutput_folder = args.outputannotation = os.path.join(input_folder, 'styles.csv')# 打开注释文件all_data = []with open(annotation) as csv_file:# 将其解析为 CSVreader = csv.DictReader(csv_file)# tqdm 显示进度条# CSV 文件中的每一行都对应于图像for row in tqdm(reader, total=reader.line_num):# 我们需要图像 ID 来构建图像文件的路径img_id = row['id']# 我们将只使用 3 个属性gender = row['gender']articleType = row['articleType']baseColour = row['baseColour']img_name = os.path.join(input_folder, 'images', str(img_id) + '.jpg')# 检查文件是否存在if os.path.exists(img_name):# 检查图像是否有 80*60 像素和 3 个通道img = Image.open(img_name)if img.size == (60, 80) and img.mode == "RGB":all_data.append([img_name, gender, articleType, baseColour])# 设置随机数生成器的种子,以便我们稍后重现结果np.random.seed(42)# 从列表中构造一个 Numpy 数组all_data = np.asarray(all_data)# 随机抽取 40000 个样本inds = np.random.choice(40000, 40000, replace=False)# 将数据拆分为 train/val 并将它们保存为 csv 文件save_csv(all_data[inds][:32000], os.path.join(output_folder, 'train.csv'))save_csv(all_data[inds][32000:40000], os.path.join(output_folder, 'val.csv'))

(2)train.py

import argparse
import os
from datetime import datetimeimport torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from test import calculate_metrics, validate, visualize_grid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdef get_cur_time():return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')def checkpoint_save(model, name, epoch):f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))torch.save(model.state_dict(), f)print('Saved checkpoint:', f)if __name__ == '__main__':parser = argparse.ArgumentParser(description='Training pipeline')parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',help="Path to the file with attributes")parser.add_argument('--device', type=str, default='cuda', help="Device: 'cuda' or 'cpu'")args = parser.parse_args()start_epoch = 1N_epochs = 50batch_size = 16num_workers = 8  # 处理数据集加载的进程数device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")# 属性变量包含数据集中类别的标签以及字符串名称和 ID 之间的映射attributes = AttributesDataset(args.attributes_file)# 在训练期间指定用于增强的图像变换train_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.8, 1.2),shear=None, resample=False, fillcolor=(255, 255, 255)),transforms.ToTensor(),transforms.Normalize(mean, std)])# 在验证期间,我们只使用张量和归一化变换val_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])train_dataset = FashionDataset('./train.csv', attributes, train_transform)train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)val_dataset = FashionDataset('./val.csv', attributes, val_transform)val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)model = MultiOutputModel(n_color_classes=attributes.num_colors,n_gender_classes=attributes.num_genders,n_article_classes=attributes.num_articles)\.to(device)optimizer = torch.optim.Adam(model.parameters())logdir = os.path.join('./logs/', get_cur_time())savedir = os.path.join('./checkpoints/', get_cur_time())os.makedirs(logdir, exist_ok=True)os.makedirs(savedir, exist_ok=True)logger = SummaryWriter(logdir)n_train_samples = len(train_dataloader)# 取消注释下面的行以查看 val 数据集中带有真实标签的示例图像和所有标签:# visualize_grid(model, val_dataloader, attributes, device, show_cn_matrices=False, show_images=True,#                checkpoint=None, show_gt=True)# print("\nAll gender labels:\n", attributes.gender_labels)# print("\nAll color labels:\n", attributes.color_labels)# print("\nAll article labels:\n", attributes.article_labels)print("Starting training ...")for epoch in range(start_epoch, N_epochs + 1):total_loss = 0accuracy_color = 0accuracy_gender = 0accuracy_article = 0for batch in train_dataloader:optimizer.zero_grad()img = batch['img']target_labels = batch['labels']target_labels = {t: target_labels[t].to(device) for t in target_labels}output = model(img.to(device))loss_train, losses_train = model.get_loss(output, target_labels)total_loss += loss_train.item()batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \calculate_metrics(output, target_labels)accuracy_color += batch_accuracy_coloraccuracy_gender += batch_accuracy_genderaccuracy_article += batch_accuracy_articleloss_train.backward()optimizer.step()print("epoch {:4d}, loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}".format(epoch,total_loss / n_train_samples,accuracy_color / n_train_samples,accuracy_gender / n_train_samples,accuracy_article / n_train_samples))logger.add_scalar('train_loss', total_loss / n_train_samples, epoch)if epoch % 5 == 0:validate(model, val_dataloader, logger, epoch, device)if epoch % 25 == 0:checkpoint_save(model, savedir, epoch)

(3)test.py

import argparse
import os
import warningsimport matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, balanced_accuracy_score
from torch.utils.data import DataLoaderdef checkpoint_load(model, name):print('Restoring checkpoint: {}'.format(name))model.load_state_dict(torch.load(name, map_location='cpu'))epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1])return epochdef validate(model, dataloader, logger, iteration, device, checkpoint=None):if checkpoint is not None:checkpoint_load(model, checkpoint)model.eval()with torch.no_grad():avg_loss = 0accuracy_color = 0accuracy_gender = 0accuracy_article = 0for batch in dataloader:img = batch['img']target_labels = batch['labels']target_labels = {t: target_labels[t].to(device) for t in target_labels}output = model(img.to(device))val_train, val_train_losses = model.get_loss(output, target_labels)avg_loss += val_train.item()batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \calculate_metrics(output, target_labels)accuracy_color += batch_accuracy_coloraccuracy_gender += batch_accuracy_genderaccuracy_article += batch_accuracy_articlen_samples = len(dataloader)avg_loss /= n_samplesaccuracy_color /= n_samplesaccuracy_gender /= n_samplesaccuracy_article /= n_samplesprint('-' * 72)print("Validation  loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}\n".format(avg_loss, accuracy_color, accuracy_gender, accuracy_article))logger.add_scalar('val_loss', avg_loss, iteration)logger.add_scalar('val_accuracy_color', accuracy_color, iteration)logger.add_scalar('val_accuracy_gender', accuracy_gender, iteration)logger.add_scalar('val_accuracy_article', accuracy_article, iteration)model.train()def visualize_grid(model, dataloader, attributes, device, show_cn_matrices=True, show_images=True, checkpoint=None,show_gt=False):if checkpoint is not None:checkpoint_load(model, checkpoint)model.eval()imgs = []labels = []gt_labels = []gt_color_all = []gt_gender_all = []gt_article_all = []predicted_color_all = []predicted_gender_all = []predicted_article_all = []accuracy_color = 0accuracy_gender = 0accuracy_article = 0with torch.no_grad():for batch in dataloader:img = batch['img']gt_colors = batch['labels']['color_labels']gt_genders = batch['labels']['gender_labels']gt_articles = batch['labels']['article_labels']output = model(img.to(device))batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \calculate_metrics(output, batch['labels'])accuracy_color += batch_accuracy_coloraccuracy_gender += batch_accuracy_genderaccuracy_article += batch_accuracy_article# get the most confident prediction for each image_, predicted_colors = output['color'].cpu().max(1)_, predicted_genders = output['gender'].cpu().max(1)_, predicted_articles = output['article'].cpu().max(1)for i in range(img.shape[0]):image = np.clip(img[i].permute(1, 2, 0).numpy() * std + mean, 0, 1)predicted_color = attributes.color_id_to_name[predicted_colors[i].item()]predicted_gender = attributes.gender_id_to_name[predicted_genders[i].item()]predicted_article = attributes.article_id_to_name[predicted_articles[i].item()]gt_color = attributes.color_id_to_name[gt_colors[i].item()]gt_gender = attributes.gender_id_to_name[gt_genders[i].item()]gt_article = attributes.article_id_to_name[gt_articles[i].item()]gt_color_all.append(gt_color)gt_gender_all.append(gt_gender)gt_article_all.append(gt_article)predicted_color_all.append(predicted_color)predicted_gender_all.append(predicted_gender)predicted_article_all.append(predicted_article)imgs.append(image)labels.append("{}\n{}\n{}".format(predicted_gender, predicted_article, predicted_color))gt_labels.append("{}\n{}\n{}".format(gt_gender, gt_article, gt_color))if not show_gt:n_samples = len(dataloader)print("\nAccuracy:\ncolor: {:.4f}, gender: {:.4f}, article: {:.4f}".format(accuracy_color / n_samples,accuracy_gender / n_samples,accuracy_article / n_samples))# 绘制混淆矩阵if show_cn_matrices:# colorcn_matrix = confusion_matrix(y_true=gt_color_all,y_pred=predicted_color_all,labels=attributes.color_labels,normalize='true')ConfusionMatrixDisplay(cn_matrix, attributes.color_labels).plot(include_values=False, xticks_rotation='vertical')plt.title("Colors")plt.tight_layout()plt.show()# gendercn_matrix = confusion_matrix(y_true=gt_gender_all,y_pred=predicted_gender_all,labels=attributes.gender_labels,normalize='true')ConfusionMatrixDisplay(cn_matrix, attributes.gender_labels).plot(xticks_rotation='horizontal')plt.title("Genders")plt.tight_layout()plt.show()# 取消下面代码的注释,查看物品混淆矩阵(可能太大无法显示)cn_matrix = confusion_matrix(y_true=gt_article_all,y_pred=predicted_article_all,labels=attributes.article_labels,normalize='true')plt.rcParams.update({'font.size': 1.8})plt.rcParams.update({'figure.dpi': 300})ConfusionMatrixDisplay(cn_matrix, attributes.article_labels).plot(include_values=False, xticks_rotation='vertical')plt.rcParams.update({'figure.dpi': 100})plt.rcParams.update({'font.size': 5})plt.title("Article types")plt.show()if show_images:labels = gt_labels if show_gt else labelstitle = "Ground truth labels" if show_gt else "Predicted labels"n_cols = 5n_rows = 3fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10))axs = axs.flatten()for img, ax, label in zip(imgs, axs, labels):ax.set_xlabel(label, rotation=0)ax.get_xaxis().set_ticks([])ax.get_yaxis().set_ticks([])ax.imshow(img)plt.suptitle(title)plt.tight_layout()plt.show()model.train()def calculate_metrics(output, target):_, predicted_color = output['color'].cpu().max(1)gt_color = target['color_labels'].cpu()_, predicted_gender = output['gender'].cpu().max(1)gt_gender = target['gender_labels'].cpu()_, predicted_article = output['article'].cpu().max(1)gt_article = target['article_labels'].cpu()with warnings.catch_warnings():  # sklearn 在处理混淆矩阵中的零行时可能会产生警告warnings.simplefilter("ignore")accuracy_color = balanced_accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())accuracy_gender = balanced_accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())accuracy_article = balanced_accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())return accuracy_color, accuracy_gender, accuracy_articleif __name__ == '__main__':parser = argparse.ArgumentParser(description='Inference pipeline')parser.add_argument('--checkpoint', type=str, default=r'checkpoints\2021-10-14_15-12\checkpoint-000050.pth', help="Path to the checkpoint")parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',help="Path to the file with attributes")parser.add_argument('--device', type=str, default='cuda',help="Device: 'cuda' or 'cpu'")args = parser.parse_args()device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")# 属性变量包含数据集中类别的标签以及字符串名称和 ID 之间的映射attributes = AttributesDataset(args.attributes_file)# 在验证期间,我们只使用张量和归一化变换val_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean, std)])test_dataset = FashionDataset('val.csv', attributes, val_transform)test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)model = MultiOutputModel(n_color_classes=attributes.num_colors, n_gender_classes=attributes.num_genders,n_article_classes=attributes.num_articles).to(device)# 训练模型的可视化visualize_grid(model, test_dataloader, attributes, device, checkpoint=args.checkpoint)

(4)辅助函数model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as modelsclass MultiOutputModel(nn.Module):def __init__(self, n_color_classes, n_gender_classes, n_article_classes):super().__init__()self.base_model = models.mobilenet_v2().features  # 取没有分类器的模型last_channel = models.mobilenet_v2().last_channel  # 分类器前层的大小# 分类器的输入应该是二维的,但我们将有# [batch_size, channels, width, height]# 所以,让我们做空间平均:将宽度和高度减少到 1self.pool = nn.AdaptiveAvgPool2d((1, 1))# 为我们的输出创建单独的分类器self.color = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_color_classes))self.gender = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_gender_classes))self.article = nn.Sequential(nn.Dropout(p=0.2),nn.Linear(in_features=last_channel, out_features=n_article_classes))def forward(self, x):x = self.base_model(x)x = self.pool(x)# 将 [batch, channels, 1, 1] 重塑为 [batch, channels] 以将其放入分类器x = torch.flatten(x, 1)return {'color': self.color(x),'gender': self.gender(x),'article': self.article(x)}def get_loss(self, net_output, ground_truth):color_loss = F.cross_entropy(net_output['color'], ground_truth['color_labels'])gender_loss = F.cross_entropy(net_output['gender'], ground_truth['gender_labels'])article_loss = F.cross_entropy(net_output['article'], ground_truth['article_labels'])loss = color_loss + gender_loss + article_lossreturn loss, {'color': color_loss, 'gender': gender_loss, 'article': article_loss}

(5)辅助函数之dataset.py

import csvimport numpy as np
from PIL import Image
from torch.utils.data import Datasetmean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]class AttributesDataset():def __init__(self, annotation_path):color_labels = []gender_labels = []article_labels = []with open(annotation_path) as f:reader = csv.DictReader(f)for row in reader:color_labels.append(row['baseColour'])gender_labels.append(row['gender'])article_labels.append(row['articleType'])self.color_labels = np.unique(color_labels)self.gender_labels = np.unique(gender_labels)self.article_labels = np.unique(article_labels)self.num_colors = len(self.color_labels)self.num_genders = len(self.gender_labels)self.num_articles = len(self.article_labels)self.color_id_to_name = dict(zip(range(len(self.color_labels)), self.color_labels))self.color_name_to_id = dict(zip(self.color_labels, range(len(self.color_labels))))self.gender_id_to_name = dict(zip(range(len(self.gender_labels)), self.gender_labels))self.gender_name_to_id = dict(zip(self.gender_labels, range(len(self.gender_labels))))self.article_id_to_name = dict(zip(range(len(self.article_labels)), self.article_labels))self.article_name_to_id = dict(zip(self.article_labels, range(len(self.article_labels))))class FashionDataset(Dataset):def __init__(self, annotation_path, attributes, transform=None):super().__init__()self.transform = transformself.attr = attributes# 初始化数组以存储真实标签和图像路径self.data = []self.color_labels = []self.gender_labels = []self.article_labels = []# 从 CSV 文件中读取注释with open(annotation_path) as f:reader = csv.DictReader(f)for row in reader:self.data.append(row['image_path'])self.color_labels.append(self.attr.color_name_to_id[row['baseColour']])self.gender_labels.append(self.attr.gender_name_to_id[row['gender']])self.article_labels.append(self.attr.article_name_to_id[row['articleType']])def __len__(self):return len(self.data)def __getitem__(self, idx):# 按索引取数据样本img_path = self.data[idx]# 读取图像img = Image.open(img_path)# 如果需要,应用图像增强if self.transform:img = self.transform(img)# 返回图像和所有相关标签dict_data = {'img': img,'labels': {'color_labels': self.color_labels[idx],'gender_labels': self.gender_labels[idx],'article_labels': self.article_labels[idx]}}return dict_data

总结

在本教程中,我们学习了如何从现有的单输出模型构建多输出模型。我们还展示了如何使用混淆矩阵检查结果的有效性。

作为最后的建议,我建议在训练之前始终检查您的数据集。通过这种方式,您可以深入了解您的数据:更好地了解您感兴趣的对象、标签及其在数据中的分布等。这通常是使您的模型获得最佳结果的重要步骤。

链接:https://pan.baidu.com/s/1F2KmD1f9jw8TMqwenMru9g
提取码:123a

参考目录

https://learnopencv.com/multi-label-image-classification-with-pytorch/

Pytorch基础知识(15)基于PyTorch的多标签图像分类相关推荐

  1. 第02章 PyTorch基础知识

    文章目录 第02章 Pytorch基础知识 2.1 张量 2.2 自动求导 2.3 并行计算简介 2.3.1 为什么要做并行计算 2.3.2 CUDA是个啥 2.3.3 做并行的方法 补充:通过股票数 ...

  2. 深入浅出Pytorch:02 PyTorch基础知识

    深入浅出Pytorch 02 PyTorch基础知识 内容属性:深度学习(实践)专题 航路开辟者:李嘉骐.牛志康.刘洋.陈安东 领航员:叶志雄 航海士:李嘉骐.牛志康.刘洋.陈安东 开源内容:http ...

  3. Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet、AlexNet、VGG、NIN、GoogleNet、ResNet)——从代码认知CNN经典架构

    Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet.AlexNet.VGG.NIN.GoogleNet.ResNet)--从代码认知CNN经典架构 目录 CNN经典算 ...

  4. pytorch基础知识+构建LeNet对Cifar10进行训练+PyTorch-OpCounter统计模型大小和参数量+模型存储与调用

    整个环境的配置请参考我另一篇博客.ubuntu安装python3.5+pycharm+anaconda+opencv+docker+nvidia-docker+tensorflow+pytorch+C ...

  5. Day01_网页开发基础知识、HTML概述、HTML标签

    01.01_网页开发基础知识 代码编写工具 python PyCharm Subline iPython... HTML HBulider Dreamweaver WebStorm Eclipse.. ...

  6. 基于Keras的多标签图像分类

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 本篇记录一下自己项目中用到的keras相关的部分.由于本项目既有涉及multi-class(多 ...

  7. Pytorch基础知识(9)单目标分割

    目标分割是在图像中寻找目标物体边界的过程.目标分割有很多应用.例如,通过勾勒医学图像中的解剖对象,临床专家可以了解有关患者病情的有用信息. 根据图像中目标的数量,我们可以进行单目标或多目标分割任务.本 ...

  8. Pytorch基础知识整理(六)参数初始化

    参数初始化的目的是限定网络权重参数的初始分布,试图让权重参数更接近参数空间的最优解,从而加速训练.pytorch中网络默认初始化参数为随机均匀分布,设定额外的参数初始化并非总能加速训练. 1,模板 在 ...

  9. pytorch基础知识整理(五) 优化器

    深度学习网络必须通过优化器进行训练.在pytorch中相关代码位于torch.optim模块中. 1, 常规用法 optimizer = torch.optim.Adam(model.paramete ...

最新文章

  1. 顺序表的结构和9个基本运算算法
  2. (转载)Web 开发人员需知的 Web 缓存知识
  3. boost::ratio_less_equal相关的测试程序
  4. yunos5 linux内核,魅蓝5S、魅蓝5对比看差异 选Android还是YunOS?
  5. 小程序开发语言python_小程序是用什么语言开发的?5种最佳语言分享
  6. Unicode - 想说爱你不容易
  7. 2021年吉林高考成绩怎么查询,2021年吉林高考成绩排名查询系统,吉林高考位次排名查询...
  8. mysql 要完 知乎_必知必会 MySQL笔记(未完)
  9. 学习时间序列法ARIMA模型与LSTM很好的文章
  10. EBP与ESP的作用
  11. 手机收不到验证码问题
  12. 如何搭建一个简单的个人网站
  13. 强制关机后进不了系统
  14. 栅栏布局合并html,arcgis栅格数据合并 arcgis栅格图像拼接步骤
  15. Apache htaccess 重写如果文件存在!
  16. ubuntu16.04 百度网盘加速下载文件
  17. 基于SpringBoot+MyBatis的餐饮点餐系统
  18. python for循环遍历涉及的相关问题及代码实现(非全部)
  19. 用英文给嵌入式计算机下定义,嵌入式课程设计报告--嵌入式系统项目设计.doc
  20. 敏捷开发绩效管理系列之八:阿米巴经营之序言

热门文章

  1. Windows部署halo并配置自启动服务
  2. springboot+dubbo+zookeeper 项目实战
  3. 图灵机器人——VQA模型的介绍
  4. HBASE Compaction 简介
  5. 重装系统后无法启动连接 mysql
  6. 哔哩哔哩电脑版怎么缓存视频?
  7. 元道N90刷机升级到安卓4.0.3
  8. 深度学习(18)机器学习常用的评价指标
  9. [UITabBar appearance]不生效
  10. 大数据数据库选型:NoSQL数据库入门