如何用PyTorch训练图像分类器
本文为 AI 研习社编译的技术博客,原标题 :
How to Train an Image Classifier in PyTorch and use it to Perform Basic Inference on Single Images
作者 | Chris Fotache
翻译 | shunshun
校对 | 酱番梨 整理 | 菠萝妹
原文链接:
https://medium.com/@chrisfotache/how-to-train-an-image-classifier-in-pytorch-and-use-it-to-perform-basic-inference-on-single-images-99465a1e9bf5
如果你刚刚开始使用PyTorch并想学习如何进行基本的图像分类,那么你可以参考本教程。它将介绍如何组织训练数据,使用预训练神经网络训练模型,然后预测其他图像。
为此,我将使用由Google地图中的地图图块组成的数据集,并根据它们包含的地形特征对它们进行分类。我会在另一篇文章中介绍如何使用它(简而言之:为了识别无人机起飞或降落的安全区域)。但是现在,我只想使用一些训练数据来对这些地图图块进行分类。
下面的代码片段来自Jupyter Notebook。你可以将它们拼接在一起以构建自己的Python脚本,或从GitHub下载。这些Notebook是基于Udacity的PyTorch课程的。如果你使用云端虚拟机进行深度学习开发并且不知道如何远程打开notebook,请查看我的教程。
组织训练数据集
PyTorch希望数据按文件夹组织,每个类对应一个文件夹。大多数其他的PyTorch教程和示例都希望你先按照训练集和验证集来组织文件夹,然后在训练集和验证集中再按照类别进行组织。但我认为这非常麻烦,必须从每个类别中选择一定数量的图像并将它们从训练集文件夹移动到验证集文件夹。由于大多数人会通过选择一组连续的文件作为验证集,因此选择可能存在很多偏差。
因此,这儿有一个将数据集快速分为训练集和测试集的更好的方法,就像Python开发人员习惯使用sklearn一样。首先,让我们导入模块:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
接下来,我们将定义train/validation数据集加载器,使用SubsetRandomSampler进行拆分:
data_dir = '/data/train'
def load_split_train_test(datadir, valid_size = .2):
train_transforms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
])
test_transforms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
])
train_data = datasets.ImageFolder(datadir,
transform=train_transforms)
test_data = datasets.ImageFolder(datadir,
transform=test_transforms)
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
np.random.shuffle(indices)
from torch.utils.data.sampler import SubsetRandomSampler
train_idx, test_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
test_sampler = SubsetRandomSampler(test_idx)
trainloader = torch.utils.data.DataLoader(train_data,
sampler=train_sampler, batch_size=64)
testloader = torch.utils.data.DataLoader(test_data,
sampler=test_sampler, batch_size=64)
return trainloader, testloader
trainloader, testloader = load_split_train_test(data_dir, .2)
print(trainloader.dataset.classes)
接下来我们将确定是否有GPU。我假设你有一台GPU机器,否则代码将至少慢10倍。但是,检查GPU可用性是个好主意。
我们还将加载预训练模型。对于这种情况,我选择ResNet 50:
device = torch.device("cuda" if torch.cuda.is_available()
else "cpu")
model = models.resnet50(pretrained=True)
print(model)
打印模型将显示ResNet模型的图层体系结构。这可能超出了我的意识或你的理解,但看到那些深层隐藏层内的东西仍然很有趣。
这取决于你选择什么样的模型,根据你的特定数据集模型可能会不同。这里列出了所有的PyTorch模型。
现在我们进入深度神经网络的有趣部分。首先,我们必须冻结预训练过的层,因此在训练期间它们不会进行反向传播。然后,我们重新定义最后的全连接层,即使用我们的图像来训练的图层。我们还创建了标准(损失函数)并选择了一个优化器(在这种情况下为Adam)和学习率。
for param in model.parameters():
param.requires_grad = False
model.fc = nn.Sequential(nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 10),
nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.003)
model.to(device)
现在完成了,让我们训练模型吧!在这个例子中只有一个epoch,但在大多数情况下你需要更多。从代码中可以看出基本过程非常直观:加载批量图像并执行前向传播循环。然后计算损失函数,并使用优化器在反向传播中应用梯度下降。
PyTorch就这么简单。下面的大多数代码是每10个批次显示损失并计算的准确度,所以你在训练运行时得到更新。在验证期间,不要忘记将模型设置为eval()模式,然后在完成后返回train()。
epochs = 1
steps = 0
running_loss = 0
print_every = 10
train_losses, test_losses = [], []
for epoch in range(epochs):
for inputs, labels in trainloader:
steps += 1
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
logps = model.forward(inputs)
loss = criterion(logps, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if steps % print_every == 0:
test_loss = 0
accuracy = 0
model.eval()
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device),
labels.to(device)
logps = model.forward(inputs)
batch_loss = criterion(logps, labels)
test_loss += batch_loss.item()
ps = torch.exp(logps)
top_p, top_class = ps.topk(1, dim=1)
equals =
top_class == labels.view(*top_class.shape)
accuracy +=
torch.mean(equals.type(torch.FloatTensor)).item()
train_losses.append(running_loss/len(trainloader))
test_losses.append(test_loss/len(testloader))
print(f"Epoch {epoch+1}/{epochs}.. "
f"Train loss: {running_loss/print_every:.3f}.. "
f"Test loss: {test_loss/len(testloader):.3f}.. "
f"Test accuracy: {accuracy/len(testloader):.3f}")
running_loss = 0
model.train()
torch.save(model, 'aerialmodel.pth')
等待几分钟后(或更长时间后,取决于数据集的大小和时期数量),完成训练并保存模型以供以后预测!
现在还有一件事可以做,即绘制训练和验证损失图:
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Validation loss')
plt.legend(frameon=False)
plt.show()
如你所见,在我的一个epoch的特定例子中,验证损失(这是我们感兴趣的)在第一个epoch结束时的平坦线条甚至开始有上升趋势,所以可能1个epoch就足够了。正如预期的那样,训练损失非常低。
现在进入第二部分。你训练模型,保存模型,并需要在应用程序中使用它。为此,你需要能够对图像执行简单推理。你也可以在我们的存储库中找到此演示notebook。我们导入与训练笔记本中相同的模块,然后再次定义变换(transforms)。我只是再次声明图像文件夹,所以我可以使用那里的一些例子:
data_dir = '/datadrive/FastAI/data/aerial_photos/train'
test_transforms = transforms.Compose([transforms.Resize(224),
transforms.ToTensor(),
])
然后我们再次检查GPU可用性,加载模型并将其置于评估模式(因此参数不会改变):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=torch.load('aerialmodel.pth')
model.eval()
预测特定图像的类的功能非常简单。请注意,它需要Pillow图像,而不是文件路径。
def predict_image(image):
image_tensor = test_transforms(image).float()
image_tensor = image_tensor.unsqueeze_(0)
input = Variable(image_tensor)
input = input.to(device)
output = model(input)
index = output.data.cpu().numpy().argmax()
return index
现在为了便于测试,我还创建了一个从数据集文件夹中选择大量随机图像的函数:
def get_random_images(num):
data = datasets.ImageFolder(data_dir, transform=test_transforms)
classes = data.classes
indices = list(range(len(data)))
np.random.shuffle(indices)
idx = indices[:num]
from torch.utils.data.sampler import SubsetRandomSampler
sampler = SubsetRandomSampler(idx)
loader = torch.utils.data.DataLoader(data,
sampler=sampler, batch_size=num)
dataiter = iter(loader)
images, labels = dataiter.next()
return images, labels
最后,为了演示预测函数,我得到随机图像样本,预测它们并显示结果:
to_pil = transforms.ToPILImage()
images, labels = get_random_images(5)
fig=plt.figure(figsize=(10,10))
for ii in range(len(images)):
image = to_pil(images[ii])
index = predict_image(image)
sub = fig.add_subplot(1, len(images), ii+1)
res = int(labels[ii]) == index
sub.set_title(str(classes[index]) + ":" + str(res))
plt.axis('off')
plt.imshow(image)
plt.show()
以下是Google地图图块上此类预测的一个示例。标签是预测的类,我也在显示它是否是正确的预测。
这就是它。继续尝试数据集。只要你正确组织图像,此代码应该按原样运行。很快我就会有更多关于神经网络和PyTorch可以做的很酷的文章。
Chris Fotache是位于 New Jersey的 CYNET.ai的人工智能研究员。他涵盖了与生活中的人工智能,Python编程,机器学习,计算机视觉,自然语言处理等相关的主题。雷锋网雷锋网雷锋网(公众号:雷锋网)
想要继续查看该篇文章相关链接和参考文献?
长按链接点击打开或点击【如何使用PyTorch训练图像分类器】:
http://ai.yanxishe.com/page/TextTranslation/1272
AI研习社每日更新精彩内容,观看更多精彩内容:
使用Python来图像增强
新手必看:手把手教你入门 Python
多目标追踪器:用OpenCV实现多目标追踪(C++/Python)
数据科学家应当了解的五个统计基本概念:统计特征、概率分布、降维、过采样/欠采样、贝叶斯统计
等你来译:
基于图像的路径规划:Dijkstra算法
掌握机器学习必须要了解的4个概念
正向和反向运动学:雅可比和微分运动
取得自然语言处理SOA结果的分层多任务学习模型(HMTL)
如何用PyTorch训练图像分类器相关推荐
- 使用PyTorch训练图像分类器
训练分类器 2019年年初,ApacheCN组织志愿者翻译了PyTorch1.0版本中文文档(github地址),同时也获得了PyTorch官方授权,我相信已经有许多人在中文文档官网上看到了.不过目前 ...
- 计算机视觉技术 图像分类_如何训练图像分类器并教您的计算机日语
计算机视觉技术 图像分类 介绍 (Introduction) Hi. Hello. こんにちは 你好 你好. こんにちは Those squiggly characters you just saw ...
- PyTorch深度学习60分钟闪电战:04 训练一个分类器
本系列是PyTorch官网Tutorial Deep Learning with PyTorch: A 60 Minute Blitz 的翻译和总结. PyTorch概览 Autograd - 自动微 ...
- 一行代码训练一个图像分类器(Luwu教程系列)
大佬们好,很久不见--(真*很久不见=.=) 很长时间没有写过博文了,为表歉意,今天给大佬们整个花活儿~ 那就是我这次要讲的主题咯--Luwu~ 那么,Luwu是啥?是本菜鸡写的一个辣鸡开源项目-- ...
- 使用Fastai开发和部署图像分类器应用
作者|KRRAI77@GMAIL.COM 编译|Flin 来源|analyticsvidhya 介绍 Fastai是一个流行的开源库,用于学习和练习机器学习以及深度学习.杰里米·霍华德(Jeremy ...
- 麻省理工研究:深度图像分类器,居然还会过度解读
作者 | 青苹果 来源 | 数据实战派 某些情况下,深度学习方法能识别出一些在人类看来毫无意义的图像,而这些图像恰恰也是医疗和自动驾驶决策的潜在隐患所在.换句话说,深度图像分类器可以使用图像的边界,而 ...
- 利用pytorch实现多分类器
%matplotlib inline 训练分类器 就是这个.您已经了解了如何定义神经网络,计算损耗并更新网络权重. 现在你可能在想 数据怎么样? 通常,当您必须处理图像,文本,音频或视频数据时,您可以 ...
- 120种小狗图像傻傻分不清?用fastai训练一个分类器
作者:一杯奶茶的功夫 链接:https://www.jianshu.com/p/ab35ed21df87 程序员转行学什么语言? https://edu.csdn.net/topic/ai30?utm ...
- 【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器
前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...
最新文章
- 肝了3版才满意:分布式系统之CAP理论,我们对它的理解和误解
- [转]40种网页常用小技巧----Ajax中国
- Tomexam在线考试系统 2.1
- 搭建一个jupyter站点做数据分析吧
- HDU 1561 The more, The Better
- 数据库原理与应用(SQL Server)笔记 第六章 数据完整性
- 收发一体超声波测距离传感器模块_芜湖低功耗超声波液位计物位计设备排名
- OpenCV-Sobel边缘检测
- 对软件开发感到惊讶的共识
- HG255d 刷最新openwrt Pandorabox并安装njit拨号
- 基于DMD实现透过多模光纤(MMF)的聚焦
- 问卷调查报告html模版,问卷调查报告的格式
- 一篇带你熟悉MySQL
- C++操作图像、图片
- 基于纠错编码的数字水印matlab,method robustness是什么意思
- 什么是FDR校正,核磁共振成像中FDR校正方法有哪些?如何进行FDR校正?
- 使用qq邮箱发送html格式的邮件
- Python数据分析案例08——预测泰坦尼克号乘员的生存(机器学习全流程)
- Word~Word修改行间距磅值
- 前端之dl dt dd vs tr td th
热门文章
- 第 11 章 直接内存
- ClickHouse在字节跳动推荐和广告业务部门的最佳实践
- java朗控点异常_Java语言基础(day_04)
- linux mysql 查询慢_linux – MySQL非常简单的SELECT查询速度极慢
- 三分钟带你弄懂slot插槽——vue进阶
- oracle显示更新条数的函数,ORACLE学习笔记-添加更新数据函数篇
- 鸿蒙os编码_终于有人把鸿蒙OS讲明白了,大佬讲解!快收藏!
- 20190809:旋转数组
- python图片转动漫_python实现了照片转化为动漫模式
- sql server数据集中取第一条记录及保留几位小数的两种做法及前n行写法