pyTorch 图像识别教程

代码:
https://github.com/dwSun/classification-tutorial.git

这里以 TinyMind 《汉字书法识别》比赛数据为例,展示使用 Pytorch 进行图像数据分类模型训练的整个流程。

数据地址请参考:
https://www.tinymind.cn/competitions/41#property_23

或到这里下载:
自由练习赛数据下载地址:
训练集:链接: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg 密码: aujd

测试集: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw 密码: 4y9k

数据探索

请参考官方的数据说明

数据处理

竞赛中只有训练集 train 数据有准确的标签,因此这里只使用 train 数据即可,实际应用中,阶段 1、2 的榜单都需要使用。

数据下载

下载数据之后进行解压,得到 train 文件夹,里面有 100 个文件夹,每个文件夹名字即是各个汉字的标签。类似的数据集结构经常在分类任务中见到。可以使用下述命令验证一下每个文件夹下面文件的数量,看数据集是否符合竞赛数据描述:

for l in $(ls); do echo $l $(ls $l|wc -l); done

划分数据集

因为这里只使用了 train 集,因此我们需要对已有数据集进行划分,供模型训练的时候做验证使用,也就是 validation 集的构建。

一般认为,train 用来训练模型,validation 用来对模型进行验证以及超参数( hyper parameter)调整,test 用来做模型的最终验证,我们所谓模型的性能,一般也是指 test 集上模型的性能指标。但是实际项目中,一般只有 train 集,同时没有可靠的 test 集来验证模型,因此一般将 train 集划分出一部分作为 validation,同时将 validation 上的模型性能作为最终模型性能指标。

一般情况下,我们不严格区分 validation 和 test。

这里将每个文件夹下面随机50个文件拿出来做 validation。

export train=train
export val=validationfor d in $(ls $train); domkdir -p $val/$d/for f in $(ls train/$d | shuf | head -n 50 ); domv $train/$d/$f $val/$d/;done;
done

需要注意,这里的 validation 只间接通过超参数的调整参与了模型训练。因此有一定的数据浪费。

模型训练代码-数据部分

首先导入 pyTorch 看一下版本

import torch
import torchvision as tvtorch.__version__
'1.4.0'

训练模型的时候,模型内部全部都是数字,没有任何可读性,而且这些数字也需要人为给予一些实际的意义,这里将 100 个汉字作为模型输出数字的文字表述。

需要注意的是,因为模型训练往往是一个循环往复的过程,因此一个稳定的文字标签是很有必要的,这里利用相关 python 代码在首次运行的时候生成了一个标签文件,后续检测到这个标签文件,则直接调用即可。

import osif os.path.exists("labels.txt"):with open("labels.txt") as inf:classes = [l.strip() for l in inf]
else:classes = os.listdir("worddata/train/")with open("labels.txt", "w") as of:of.write("\r\n".join(classes))class_idx = {v: k for k, v in enumerate(classes)}
idx_class = dict(enumerate(classes))

pyTorch里面,classes有自己的组织方式,这里我们想要自定义,要做一下转换。

from PIL import Imagepth_classes = classes[:]
pth_classes.sort()
pth_classes_to_idx = {v: k for k, v in enumerate(pth_classes)}def target_transform(pth_idx):return class_idx[pth_classes[pth_idx]]

pyTorch 中提供了直接从目录中读取数据并进行训练的 API 这里使用的API如下。

这里使用了两个数据集,分别代表 train、validation。

需要注意的是,由于 数据中,使用的图像数据集,其数值在(0, 255)之间。同时,pyTorch 用 pillow 来处理图像的加载,其图像的数据layout是(H,W,C),而 pyTorch用来训练的数据需要是(C,H,W)的,因此需要对数据做一些转换。另外,train 数据集做了一定的数据预处理(旋转、明暗度),用于进行数据增广,也做了数据打乱(shuffle),而 validation则不需要做类似的变换。

这里有一些地方需要注意一下,RandomRotation 我们使用了 expand 所以每次输出图像大小都不同,resize 操作要放在后面。pyTorch 中我没找到如何直接用灰度方式读取图像,对于汉字来说,色彩没有任何意义。因此这里用 Grayscale 来转换图像为灰度。ToTensor这个操作会转换数据的 layout,因此要放在最后面。

from multiprocessing import cpu_counttransform_train = tv.transforms.Compose([# tv.transforms.RandomRotation((-15, 15), expand=True),tv.transforms.RandomRotation((-15, 15)),tv.transforms.Resize((128, 128)),tv.transforms.ColorJitter(brightness=0.5),tv.transforms.Grayscale(),tv.transforms.ToTensor(),]
)
transform_val = tv.transforms.Compose([tv.transforms.Resize((128, 128)),tv.transforms.Grayscale(),tv.transforms.ToTensor(),]
)img_gen_train = tv.datasets.ImageFolder("worddata/train/", transform=transform_train, target_transform=target_transform
)img_gen_val = tv.datasets.ImageFolder("worddata/validation/", transform=transform_val, target_transform=target_transform
)batch_size = 32img_train = torch.utils.data.DataLoader(img_gen_train, batch_size=batch_size, shuffle=True, num_workers=cpu_count()
)
img_val = torch.utils.data.DataLoader(img_gen_val, batch_size=batch_size, num_workers=cpu_count()
)

到这里,这两个数据集就可以使用了,正式模型训练之前,我们可以先来看看这个数据集是怎么读取数据的,读取出来的数据又是设么样子的。

for imgs, labels in img_train:# img_train 只部分满足 generator 的语法,不能用 next 来获取数据break
imgs.shape, labels.shape
(torch.Size([32, 1, 128, 128]), torch.Size([32]))

可以看到数据是(batch, channel, height, width, height), 因为这里是灰度图像,因此 channel 是 1。

需要注意,pyTorch、mxnet使用的数据 layout 与Tensorflow 不同,因此数据也有一些不同的处理方式。

把图片打印出来看看,看看数据和标签之间是否匹配

import numpy as np
from matplotlib import pyplot as pltplt.imshow(imgs[0, 0, :, :], cmap="gray")
classes[labels[0]]
'寒'

模型训练代码-模型构建

pyTorch 中使用静态图来构建模型,模型构建比较简单。这里演示的是使用 class 的方式构建模型,对于简单模型,还可以直接使用 Sequential 进行构建。

这里的复杂模型也是用 Sequential 的简单模型进行的叠加。

这里构建的是VGG模型,关于VGG模型的更多细节请参考 1409.1556。

class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()# 模型有两个主要部分,特征提取层和分类器# 这里是特征提取层self.feature = torch.nn.Sequential()self.feature.add_module("conv1", self.conv(1, 64))self.feature.add_module("conv2", self.conv(64, 64, add_pooling=True))self.feature.add_module("conv3", self.conv(64, 128))self.feature.add_module("conv4", self.conv(128, 128, add_pooling=True))self.feature.add_module("conv5", self.conv(128, 256))self.feature.add_module("conv6", self.conv(256, 256))self.feature.add_module("conv7", self.conv(256, 256, add_pooling=True))self.feature.add_module("conv8", self.conv(256, 512))self.feature.add_module("conv9", self.conv(512, 512))self.feature.add_module("conv10", self.conv(512, 512, add_pooling=True))self.feature.add_module("conv11", self.conv(512, 512))self.feature.add_module("conv12", self.conv(512, 512))self.feature.add_module("conv13", self.conv(512, 512, add_pooling=True))self.feature.add_module("avg", torch.nn.AdaptiveAvgPool2d((1, 1)))self.feature.add_module("flatten", torch.nn.Flatten())self.feature.add_module("linear1", torch.nn.Linear(512, 4096))self.feature.add_module("act_linear_1", torch.nn.ReLU())self.feature.add_module("bn_linear_1", torch.nn.BatchNorm1d(4096))self.feature.add_module("linear2", torch.nn.Linear(4096, 4096))self.feature.add_module("act_linear_2", torch.nn.ReLU())self.feature.add_module("bn_linear_2", torch.nn.BatchNorm1d(4096))self.feature.add_module("dropout", torch.nn.Dropout())# 这个简单的机构是分类器self.pred = torch.nn.Linear(4096, 100)def conv(self, in_channels, out_channels, add_pooling=False):# 模型大量使用重复模块构建,# 这里将重复模块提取出来,简化模型构建过程model = torch.nn.Sequential()model.add_module("conv", torch.nn.Conv2d(in_channels, out_channels, 3, padding=1))model.add_module("act_conv", torch.nn.ReLU())model.add_module("bn_conv", torch.nn.BatchNorm2d(out_channels))if add_pooling:model.add_module("pool", torch.nn.MaxPool2d((2, 2)))return modeldef forward(self, x):# call 用来定义模型各个结构之间的运算关系x = self.feature(x)return self.pred(x)

可以看到,这里必须指定网络输入输出,对比 TF 和 mxnet 不是很方便。

实例化一个模型看看:

model = MyModel()
model
MyModel((feature): Sequential((conv1): Sequential((conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv2): Sequential((conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False))(conv3): Sequential((conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv4): Sequential((conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False))(conv5): Sequential((conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv6): Sequential((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv7): Sequential((conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False))(conv8): Sequential((conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv9): Sequential((conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv10): Sequential((conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False))(conv11): Sequential((conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv12): Sequential((conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(conv13): Sequential((conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(act_conv): ReLU()(bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False))(avg): AdaptiveAvgPool2d(output_size=(1, 1))(flatten): Flatten()(linear1): Linear(in_features=512, out_features=4096, bias=True)(act_linear_1): ReLU()(bn_linear_1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(linear2): Linear(in_features=4096, out_features=4096, bias=True)(act_linear_2): ReLU()(bn_linear_2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(dropout): Dropout(p=0.5, inplace=False))(pred): Linear(in_features=4096, out_features=100, bias=True)
)

模型训练代码-训练相关部分

要训练模型,我们还需要定义损失,优化器等。

loss_object = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  # 优化器有些参数可以设置
import time  # 模型训练的过程中手动追踪一下模型的训练速度

因为模型整个训练过程一般是一个循环往复的过程,所以经常性的保存重启模型训练中间过程是有必要的。
这里我们一个ckpt保存了两份,便于中断模型的重新训练。

import osgpu = 1model.cuda(gpu)
if os.path.exists("model_ckpt.pth"):# 检查 checkpoint 是否存在# 如果存在,则加载 checkpointnet_state, optm_state = torch.load("model_ckpt.pth")model.load_state_dict(net_state)optimizer.load_state_dict(optm_state)# 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,# 手动选择准确率最高的某次 checkpoint 进行加载。print("model lodaded")
EPOCHS = 40
for epoch in range(EPOCHS):train_loss = 0train_accuracy = 0train_samples = 0val_loss = 0val_accuracy = 0val_samples = 0start = time.time()for imgs, labels in img_train:imgs = imgs.cuda(gpu)labels = labels.cuda(gpu)preds = model(imgs)loss = loss_object(preds, labels)optimizer.zero_grad()loss.backward()optimizer.step()train_samples += imgs.shape[0]train_loss += loss.item()train_accuracy += (preds.argmax(dim=1) == labels).sum().item()train_samples_per_second = train_samples / (time.time() - start)start = time.time()for imgs, labels in img_val:imgs = imgs.cuda(gpu)labels = labels.cuda(gpu)model.eval()preds = model(imgs)model.train()val_loss += loss.item()val_accuracy += (preds.argmax(dim=1) == labels).sum().item()val_samples += imgs.shape[0]val_samples_per_second = val_samples / (time.time() - start)print("Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}".format(epoch,train_loss * batch_size / train_samples,train_accuracy * 100 / train_samples,val_loss * batch_size / val_samples,val_accuracy * 100 / val_samples,))print("Speed train {}imgs/s val {}imgs/s".format(train_samples_per_second, val_samples_per_second))torch.save((model.state_dict(), optimizer.state_dict()), "model_ckpt.pth")torch.save((model.state_dict(), optimizer.state_dict()),"model_ckpt-{:04d}.pth".format(epoch),)# 每个 epoch 保存一下模型,需要注意每次# 保存要用一个不同的名字,不然会导致覆盖,# 同时还要关注一下磁盘空间占用,防止太多# chekcpoint 占满磁盘空间导致错误。
Epoch 0 Loss 6.379009783063616, Acc 0.96, Val Loss 7.499249600219726, Val Acc 1.06
Speed train 234.8145027420376imgs/s val 706.7103467600252imgs/s
Epoch 1 Loss 6.250997774396624, Acc 1.1485714285714286, Val Loss 8.049263702392578, Val Acc 1.12
Speed train 230.86756209036096imgs/s val 684.4524146021467imgs/s
Epoch 2 Loss 6.0144778276715956, Acc 1.16, Val Loss 5.387196655273438, Val Acc 1.22
Speed train 226.12469399959375imgs/s val 681.3837266883015imgs/s
Epoch 3 Loss 5.589597338867187, Acc 1.0057142857142858, Val Loss 5.907029174804688, Val Acc 1.1
Speed train 225.0556244090409imgs/s val 680.1839950846079imgs/s
Epoch 4 Loss 5.402270581054688, Acc 1.1714285714285715, Val Loss 5.1295126434326175, Val Acc 1.4
Speed train 224.72032368819234imgs/s val 682.6134263746877imgs/s
Epoch 5 Loss 5.175169513811384, Acc 1.062857142857143, Val Loss 5.386006506347656, Val Acc 0.78
Speed train 224.69030177626445imgs/s val 680.2019310561357imgs/s
Epoch 6 Loss 4.945824640328544, Acc 1.2485714285714287, Val Loss 5.106301385498047, Val Acc 1.7
Speed train 224.41665937455835imgs/s val 680.7435881896529imgs/s
Epoch 7 Loss 4.78519496547154, Acc 1.2714285714285714, Val Loss 5.001887857055664, Val Acc 1.46
Speed train 224.3077478615822imgs/s val 679.1943518703032imgs/s
Epoch 8 Loss 4.681244001116071, Acc 1.5542857142857143, Val Loss 4.678347979736328, Val Acc 1.76
Speed train 224.43681583771664imgs/s val 680.9421894383488imgs/s
Epoch 9 Loss 4.594511456734794, Acc 1.977142857142857, Val Loss 4.734268209838867, Val Acc 3.48
Speed train 224.43505593143354imgs/s val 679.8263336556521imgs/s
Epoch 10 Loss 4.564881538609096, Acc 2.2142857142857144, Val Loss 4.5007019622802735, Val Acc 3.5
Speed train 224.35177417457732imgs/s val 681.4680194348261imgs/s
Epoch 11 Loss 4.359355766732352, Acc 3.797142857142857, Val Loss 4.303946963500977, Val Acc 5.1
Speed train 224.22713261480806imgs/s val 678.4494364538398imgs/s
Epoch 12 Loss 4.05738628692627, Acc 6.651428571428571, Val Loss 3.5746582946777345, Val Acc 10.42
Speed train 224.18908188445624imgs/s val 679.7794406874162imgs/s
Epoch 13 Loss 3.7937849918910436, Acc 10.214285714285714, Val Loss 3.7133444229125976, Val Acc 7.04
Speed train 224.10400818300923imgs/s val 679.113765396056imgs/s
Epoch 14 Loss 3.2694146046229773, Acc 19.425714285714285, Val Loss 3.6981288192749022, Val Acc 30.78
Speed train 224.13154184979035imgs/s val 680.3901503258012imgs/s
Epoch 15 Loss 2.7287981418064664, Acc 31.591428571428573, Val Loss 2.7384859634399414, Val Acc 43.0
Speed train 224.08093870796063imgs/s val 680.9818793692156imgs/s
Epoch 16 Loss 2.4017765145438057, Acc 40.222857142857144, Val Loss 2.373513427734375, Val Acc 55.06
Speed train 224.0240886741711imgs/s val 680.9301175413847imgs/s
Epoch 17 Loss 1.9575243755885532, Acc 50.81428571428572, Val Loss 1.8042015686035155, Val Acc 60.3
Speed train 224.023773128244imgs/s val 679.3334877190623imgs/s
Epoch 18 Loss 1.8670056664603096, Acc 52.754285714285714, Val Loss 1.7974752388000488, Val Acc 59.12
Speed train 224.04512456698183imgs/s val 677.1390343683371imgs/s
Epoch 19 Loss 1.6107693487439836, Acc 58.48571428571429, Val Loss 2.0469212783813475, Val Acc 66.72
Speed train 223.87349316639512imgs/s val 678.0275851549168imgs/s
Epoch 20 Loss 1.7171708895547049, Acc 56.642857142857146, Val Loss 1.8279149505615235, Val Acc 65.96
Speed train 223.90746153613367imgs/s val 677.346826146328imgs/s
Epoch 21 Loss 1.2915482904706683, Acc 65.76285714285714, Val Loss 1.3189221771240234, Val Acc 68.62
Speed train 223.9696856948392imgs/s val 680.7548137514192imgs/s
Epoch 22 Loss 1.1914144684110368, Acc 68.43714285714286, Val Loss 1.0220409889221191, Val Acc 67.86
Speed train 223.9880397987904imgs/s val 679.2464221336535imgs/s
Epoch 23 Loss 1.0181893185751778, Acc 72.81428571428572, Val Loss 0.6417443874359131, Val Acc 75.94
Speed train 223.8553653733427imgs/s val 678.307898091706imgs/s
Epoch 24 Loss 0.9370736787523543, Acc 75.12, Val Loss 1.0853789276123047, Val Acc 76.04
Speed train 223.88541791918547imgs/s val 681.0767780628761imgs/s
Epoch 25 Loss 0.858675898034232, Acc 76.78, Val Loss 0.6966076656341553, Val Acc 76.14
Speed train 223.89225800022191imgs/s val 677.8778524201055imgs/s
Epoch 26 Loss 0.911681534739903, Acc 75.55142857142857, Val Loss 1.2748067726135255, Val Acc 75.28
Speed train 223.8182146384268imgs/s val 678.3320102990313imgs/s
Epoch 27 Loss 0.7263422344616481, Acc 80.27428571428571, Val Loss 0.8662283229827881, Val Acc 76.24
Speed train 223.84745302957168imgs/s val 677.5476529925469imgs/s
Epoch 28 Loss 0.7096801671164377, Acc 80.64, Val Loss 0.4879303056716919, Val Acc 78.4
Speed train 223.82907588526868imgs/s val 679.641487715212imgs/s
Epoch 29 Loss 0.8400143226759774, Acc 77.24857142857142, Val Loss 0.48099885005950926, Val Acc 76.28
Speed train 223.79857240610696imgs/s val 677.2004991093537imgs/s
Epoch 30 Loss 0.6340663018226623, Acc 82.75428571428571, Val Loss 0.3814028434753418, Val Acc 77.88
Speed train 223.69922632756038imgs/s val 677.8182801677248imgs/s
Epoch 31 Loss 0.6143715186391558, Acc 83.12571428571428, Val Loss 1.5937435668945312, Val Acc 55.9
Speed train 223.73651624919145imgs/s val 677.9854113416313imgs/s
Epoch 32 Loss 0.6921936396871294, Acc 80.92285714285714, Val Loss 0.6802982173919677, Val Acc 74.06
Speed train 223.6073489617921imgs/s val 675.2316887830606imgs/s
Epoch 33 Loss 0.6144891169275556, Acc 83.18, Val Loss 0.46930033054351805, Val Acc 76.34
Speed train 223.57683437615302imgs/s val 677.3064650659536imgs/s
Epoch 34 Loss 0.568616727393014, Acc 84.28, Val Loss 0.4940680891036987, Val Acc 79.08
Speed train 223.53035806536994imgs/s val 675.1831667819793imgs/s
Epoch 35 Loss 0.5646722382409232, Acc 84.30571428571429, Val Loss 0.4494327730178833, Val Acc 79.24
Speed train 223.53708898876783imgs/s val 676.0269683297067imgs/s
Epoch 36 Loss 0.977967550604684, Acc 74.1, Val Loss 0.8460039363861084, Val Acc 74.68
Speed train 223.70501802195884imgs/s val 675.9956110209276imgs/s
Epoch 37 Loss 0.7239568670545306, Acc 80.12857142857143, Val Loss 1.048443465423584, Val Acc 80.18
Speed train 223.7010776600049imgs/s val 677.2457684330305imgs/s
Epoch 38 Loss 0.5576571273531232, Acc 84.37714285714286, Val Loss 0.7641737712860107, Val Acc 78.96
Speed train 223.92617365625426imgs/s val 679.0357044899268imgs/s
Epoch 39 Loss 0.4953382140840803, Acc 86.30285714285715, Val Loss 0.9396348545074463, Val Acc 80.98
Speed train 223.80456065939208imgs/s val 677.9383775081458imgs/s

一些技巧

因为这里定义的模型比较大,同时训练的数据也比较多,每个 epoch 用时较长,因此,如果代码有 bug 的话,经过一次 epoch 再去 debug 效率比较低。

这种情况下,我们使用的数据生成过程又是自己手动指定数据数量的,因此可以尝试缩减模型规模,定义小一些的数据集来快速验证代码。在这个例子里,我们可以通过注释模型中的卷积和全连接层的代码来缩减模型尺寸,通过修改训练循环里面的数据数量来缩减数据数量。

训练的速度很慢

类似的网络结构和参数,TF里面 20epochs能达到90%的准确率,这里要40epochs才能到86%,应该是哪里有什么问题,我再看看怎么解决。


下面的代码属于另外一个文件,因此部分代码跟上面是重复的。

模型的使用代码

模型训练好了之后要实际应用。对于模型部署有很多成熟的方案,如 Nvidia 的 TensorRT, Intel 的 OpenVINO 等,都可以做模型的高效部署,这里限于篇幅不涉及相关内容。

在模型训练过程中,也可以使用使用框架提供的 API 做模型的简单部署以方便开发。

import torch
import torchvision as tv
import os
torch.__version__
'1.4.0'

首先要加载模型的标签用于展示,因为我们训练的时候就已经生成了标签文件,这里直接用写好的代码就可以。

if os.path.exists("labels.txt"):with open("labels.txt") as inf:classes = [l.strip() for l in inf]
else:classes = os.listdir("worddata/train/")with open("labels.txt", "w") as of:of.write("\r\n".join(classes))

接着是模型的定义,这里直接将训练中使用的模型代码拿来即可。

class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()# 模型有两个主要部分,特征提取层和分类器# 这里是特征提取层self.feature = torch.nn.Sequential()self.feature.add_module("conv1", self.conv(1, 64))self.feature.add_module("conv2", self.conv(64, 64, add_pooling=True))self.feature.add_module("conv3", self.conv(64, 128))self.feature.add_module("conv4", self.conv(128, 128, add_pooling=True))self.feature.add_module("conv5", self.conv(128, 256))self.feature.add_module("conv6", self.conv(256, 256))self.feature.add_module("conv7", self.conv(256, 256, add_pooling=True))self.feature.add_module("conv8", self.conv(256, 512))self.feature.add_module("conv9", self.conv(512, 512))self.feature.add_module("conv10", self.conv(512, 512, add_pooling=True))self.feature.add_module("conv11", self.conv(512, 512))self.feature.add_module("conv12", self.conv(512, 512))self.feature.add_module("conv13", self.conv(512, 512, add_pooling=True))self.feature.add_module("avg", torch.nn.AdaptiveAvgPool2d((1, 1)))self.feature.add_module("flatten", torch.nn.Flatten())self.feature.add_module("linear1", torch.nn.Linear(512, 4096))self.feature.add_module("act_linear_1", torch.nn.ReLU())self.feature.add_module("bn_linear_1", torch.nn.BatchNorm1d(4096))self.feature.add_module("linear2", torch.nn.Linear(4096, 4096))self.feature.add_module("act_linear_2", torch.nn.ReLU())self.feature.add_module("bn_linear_2", torch.nn.BatchNorm1d(4096))self.feature.add_module("dropout", torch.nn.Dropout())# 这个简单的机构是分类器self.pred = torch.nn.Linear(4096, 100)def conv(self, in_channels, out_channels, add_pooling=False):# 模型大量使用重复模块构建,# 这里将重复模块提取出来,简化模型构建过程model = torch.nn.Sequential()model.add_module("conv", torch.nn.Conv2d(in_channels, out_channels, 3, padding=1))model.add_module("act_conv", torch.nn.ReLU())model.add_module("bn_conv", torch.nn.BatchNorm2d(out_channels))if add_pooling:model.add_module("pool", torch.nn.MaxPool2d((2, 2)))return modeldef forward(self, x):# call 用来定义模型各个结构之间的运算关系x = self.feature(x)return self.pred(x)

有了模型的定义之后,我们可以加载训练好的模型,跟模型训练的时候类似,我们可以直接加载模型训练中的 checkpoint。

import os
model = MyModel().cuda()if os.path.exists('ckpts_pth/model_ckpt.pth'):# 检查 checkpoint 是否存在# 如果存在,则加载 checkpointnet_state, optm_state = torch.load('ckpts_pth/model_ckpt.pth')model.load_state_dict(net_state)# 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,# 手动选择准确率最高的某次 checkpoint 进行加载。print("model lodaded")
model lodaded

对于数据,我们需要直接处理图片,因此这里导入一些图片处理的库和数据处理的库

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

直接打开某个图片

img = Image.open("worddata/validation/从/116e891836204e4e67659d2b73a7e4780a37c301.jpg")plt.imshow(img, cmap="gray")

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hNIynK8t-1587993165325)(output_11_1.png)]

需要注意,模型在训练的时候,我们对数据进行了一些处理,在模型使用的时候,我们要对数据做一样的处理,如果不做的话,模型最终的结果会出现不可预料的问题。

img = img.resize((128, 128))
img = np.array(img) / 255
img.shape
(128, 128)

模型对图片数据的运算其实很简单,一行代码就可以。

这里需要注意模型处理的数据是 4 维的,而上面的图片数据实际是 2 维的,因此要对数据进行维度的扩充。同时模型的输出是 2 维的,带 batch ,所以需要压缩一下维度。

model.eval()
pred = np.squeeze(model(torch.Tensor(img[np.newaxis, np.newaxis, :, :]).cuda()))
pred = torch.nn.functional.softmax(pred)
pred.argsort()[-5:]print([pred[idx].item() for idx in pred.argsort()[-5:]])
print([classes[idx] for idx in pred.argsort()[-5:]])
[7.042408323165716e-11, 1.551086897810805e-10, 2.2588204917628474e-10, 4.854148372146483e-08, 1.0]
['遂', '夜', '御', '作', '从']/home/dl/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.after removing the cwd from sys.path.

这里只给出了 top5 的结果,可以看到,准确率还是不错的。

pyTorch 图像分类模型训练教程相关推荐

  1. 【pytorch速成】Pytorch图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...

  2. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  3. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型

    @Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...

  4. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  5. 使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)

    开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template​​​​​ 使用pytorch框架搭建一个图像分类模型通常包含 ...

  6. Pytorch通用图像分类模型(支持20+分类模型),直接带入数据就可训练自己的数据集,包括模型训练、推理、部署。

    Pytorch-Image-Classifier-Collection 介绍 ============================== 支持多模型工程化的图像分类器 =============== ...

  7. pytorch图像分类_使用PyTorch和Streamlit创建图像分类Web应用

    pytorch图像分类 You just developed a cool ML model. 您刚刚开发了一个很酷的ML模型. You are proud of it. You want to sh ...

  8. mnist数据集彩色图像_使用MNIST数据集构建多类图像分类模型。

    mnist数据集彩色图像 Below are the steps to build a model that can classify handwritten digits with an accur ...

  9. [Pytorch图像分类全流程实战]Task06:可解释性分析

    目录 前言 CAM热力图系列算法 [A]安装配置环境 [B] torchcam命令行 [C1]Pytorch预训练ImageNet图像分类-单张图像 [C2] Pytorch预训练lmageNet图像 ...

最新文章

  1. java 日志 生成器_自动生成 java 测试 mock 对象框架 DataFactory-01-入门使用教程
  2. PHP中正则表达式学习及应用(二)
  3. C++字符串函数与C字符串函数比较
  4. [转] C# 获取程序运行目录
  5. 文件上传(input为file类型)
  6. Python新手常见错误汇总|附代码检查清单
  7. wait和notify使用例子
  8. SQL2005的配置
  9. java泛型函数类型推断_为什么javac可以推断用作参数的函数的泛型类型参数?
  10. Bailian1192 最优连通子集【DFS】
  11. mysql longbolb_MySql基本数据类型及约束
  12. 解决启动Eclipse后提示’Running android lint’错误的问题
  13. 记录VS在线安装下载慢的解决
  14. arima 公式_R时间序列分析(8)ARIMA(上)
  15. java opts tomcat,tomcat JAVA_OPTS配备
  16. IT新人的辛酸反省与总结
  17. 安卓camera2 API获取YUV420_888格式详解
  18. 记录Springboot+Mybatis_Plus进行CRUD与分页的注意点
  19. 【宇麦科技】群晖NAS套件之Drive的客户端安装与配置(二),新手必读!
  20. vivado下ERROR: [USF-XSim-62] [VRFC 10-3180]

热门文章

  1. android edittext背景颜色,Android 设置 EditText 背景颜色、背景图片
  2. python .txt文件转.csv文件-ok
  3. Flume 数据采集组件
  4. java微信小程序接口openid过期_Java微信小程序登录接口获取openid
  5. AD软件出现“Your license is already used on computer “LAPTOP-F99R6OR1“ using product “AltiumDesigner“
  6. python大数据技术_大数据技术python
  7. 大数据用kettle还是python_kettle大数据抽取实际
  8. PROE产品设计:20个机械设计知识点
  9. java预研项目_YAML预研文档
  10. 多点相册--将手机的照片和视频备份到电脑的工具