图像来自:快速上手笔记,PyTorch模型训练实用教程(附代码) - 知乎

目录

1、数据处理模块搭建

2、模型构建

3、开始训练

4、评估模型

5、使用模型进行预测

6、保存模型


1、数据处理模块搭建

这里需要根据自己的数据集进行选择合适的方法,这里就以图像分类作为一个例子来说明。

通常有两种方法:

(1)采用torchvision中的datasets.ImageFolder来读取图像,然后采用torch.utils.data.DataLoader加载;

Ps:这种情况一般是想要读取一自己在一个文件夹中的数据作为数据集
具体的形式如下:
dataset/cat/0.jpg1.jpgdog/0.jpg1.jpg--------------------------
这种情况使用ImageFolder就比较方便

(2)继承torch.utils.data.Dataset来实现用户自定义,然后采用torch.utils.data.DataLoader加载;

torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。Pytorch提供两种数据集: Map式数据集 Iterable式数据集对于Map式数据集处理方式:重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map).当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。上述参考:https://zhuanlan.zhihu.com/p/105507334

自定义模块可以参考:

class CustomDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Datasetdef __init__(self):# TODO# 1. Initialize file path or list of file names.passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).# 这里需要注意的是,第一步:read one data,是一个datapassdef __len__(self):# You should change 0 to the total size of your dataset.return 0

参考一个实例:

from torch.utils import data
import numpy as np
from PIL import Image# 参考:https://zhuanlan.zhihu.com/p/105507334
class face_dataset(data.Dataset):def __init__(self):# 数据集的路径self.file_path = './data/faces/'# 对应的数据集和标签,这里是保存在txt文件中的,也有的是json文件,或者csv文件等# 根据自己的数据集情况而定f = open("final_train_tag_dict.txt","r")self.label_dict = eval(f.read())f.close()def __getitem__(self,index):"""通过index返回对应的img和label"""label = list(self.label_dict.values())[index-1]img_id = list(self.label_dict.keys())[index-1]img_path = self.file_path+str(img_id)+".jpg"img = np.array(Image.open(img_path))return img,labeldef __len__(self):# 返回整个数据集的数量return len(self.label_dict)

在这里我采用第一种形式,因为我采用的数据集是下面这种形式:

每个文件对应一个类别,如果你采用的数据集是给定了一个image_label.txt或者image_label.csv,则采用第二种数据处理方法比较方便;

第一种方法的实现代码如下:

from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,datasets# 1、Data augmentation
# https://pytorch.org/vision/stable/transforms.html
# 数据增强部分可根据自己的情况选择,可以参考官方代码
transforms_train = transforms.Compose([transforms.ToTensor(),transforms.ColorJitter(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])
# valid不需要数据增强
transforms_valid = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])# 2、load dataset
ds_train = datasets.ImageFolder("../data/train/",transform=transforms_train,target_transform=lambda t:torch.tensor([t]).float())
ds_valid = datasets.ImageFolder("../data/test/",transform=transforms_valid,target_transform=lambda t:torch.tensor([t]).float())

官方文档:https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolderhttps://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolder

torchvision.datasets.ImageFolder(root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Optional[Callable[[str], bool]] = None)

Parameters:1、root (string) – Root directory path.->数据集地址2、transform (callable, optional) – A function/transform that takes
in an PIL image and returns a transformed version. E.g, transforms.RandomCrop3、target_transform (callable, optional) – A function/transform that
takes in the target and transforms it.主要是处理对应的图像标签4、loader (callable, optional) – A function to load an image given
its path.5、is_valid_file – A function that takes path of an Image file and
check if the file is a valid file (used to check of corrupt files)
检查数据集中图像是否损坏Returns:通过:__getitem__(index: int) → 得到:Tuple[Any, Any]1、(sample, target) where target is class_index of the target class.

经过ImageLoader之后的数据具体是什么格式?

从上图可以看出返回的samples中是一个元组(图像的地址,图像的标签);

targets对应每张图像的标签,classes所有数据的类别,class_to_idx类别索引,extensions图像支持的扩张名等

# 查看数据集中的类别
print(ds_train)
print(ds_valid.classes)
# 每个类别对应的标签
print(ds_valid.class_to_idx)

经过ImageLoader处理后,还需要经过DataLoader进一步处理:

# 通过DataLoader加载ImageFolder
# 这里的num_workers为了避免出错,尽量设置为0
dl_train = DataLoader(ds_train,batch_size=50,shuffle=True,num_workers=0)
dl_valid = DataLoader(ds_valid,batch_size=50,shuffle=True,num_workers=0)

注意:这个num_workers如果设置为其他数字,刚开始可能没问题,但是后续会可能会出现问题,不妨设置为0;

官方文档:

https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoaderhttps://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

介绍:

DataLoader.:Combines a dataset and a sampler, and provides an iterable over the given dataset.

包括一个数据集和一个采样器,并且提供一个给定数据集的可迭代对象;

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

DataLoader支持的格式比较多,本次采用的是map-style;

看一下经过DataLoader之后的数据形式:

查看一下数据集中的部分样本:

import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
for i in range(9):# ds_train[i]也可以img,label = ds_valid[i]# 图像是b*c*w*h->b*w*h*cimg = img.permute(1,2,0)ax = plt.subplot(3,3,i+1)ax.imshow(img.numpy())ax.set_title("label = %d"%label.item(),fontsize=8)ax.set_xticks([])ax.set_yticks([])
plt.show()

2、模型构建

(1)使用torch.nn.Sequential按层顺序构建模型;

(2)继承torch.nn.Module基类构建模型;

(3)继承torch.nn.Module基类构建并辅助应用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict);

nn.Sequential案例

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))

而对于nn.Module通过官方的一个example:

它是所有神经网络模块的基类,自己定义的模型应该继承这个类。

同时该模块还可以包含其他模块,允许将它们嵌套在树结构中。

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))

本文采用继承nn.Module创建model

class Image_Net(nn.Module):def __init__(self):super(Image_Net,self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3)self.pool = nn.MaxPool2d(kernel_size=2,stride=2)self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5)self.dropout = nn.Dropout2d(p=0.2)self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))self.flatten = nn.Flatten()self.linear1 = nn.Linear(64,32)self.relu = nn.ReLU()self.linear2 = nn.Linear(32,1)self.sigmoid = nn.Sigmoid()def forward(self,x):x = self.conv1(x)x = self.pool(x)x = self.conv2(x)x = self.dropout(x)x = self.adaptive_pool(x)x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)y = self.sigmoid(x)return y# 实例化
net = Image_Net()
print(net)

3、开始训练

首先设置一些训练参数

import pandas as pd
# 其他指标可以查看sklearn.metrics
from sklearn.metrics import roc_auc_score
model = Image_Net()
model.optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
model.loss_func = torch.nn.BCELoss()
model.metric_func = lambda y_pred,y_true:roc_auc_score(y_true.data.numpy(),y_pred.data.numpy())
model.metric_name = "auc"

下面采用函数式训练循环

首先创建train模块

def train(model,features,labels):""":param model: :param features: :param labels: :return: loss & metric"""# 训练模式,dropout层发生作用model.train()# 梯度清零model.optimizer.zero_grad()# 正向传播求损失predictions = model(features)# 计算损失loss = model.loss_func(predictions,labels)# metric计算,这里选择的是AUCmetric = model.metric_func(predictions,labels)# 反向传播求梯度loss.backward()model.optimizer.step()return loss.item(),metric.item()

然后创建valid模块:

def valid(model,features,labels):"""因为只是验证所以不对模型的参数进行更新,只需要输出对应的结果就行:param model: :param features: :param labels: :return: loss & metric"""# 预测模式,dropout层不发生作用model.eval()predictions = model(features)loss = model.loss_func(predictions,labels)metric = model.metric_func(predictions,labels)return loss.item(),metric.item()

设置GPU

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device) # 移动模型到cuda

还有一个重要的事情,就是GPU用完一定要及时释放:在上述设置GPU的情况下,增加下述代码:

torch.cuda.empty_cache()

tensorflow中清理显存的方法:解决tensorflow占用GPU显存问题

完整的训练代码如下:

import datetime
def train_model(model, epochs, dl_train, dl_valid, log_step_freq):metric_name = model.metric_name# 用于记录训练过程中的loss和metricdfhistory = pd.DataFrame(columns=["epoch", "loss", metric_name, "val_loss", "val_" + metric_name])print("Start Training...")nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print("==========" * 8 + "%s" % nowtime)for epoch in range(1, epochs + 1):# 1,训练循环-------------------------------------------------loss_sum = 0.0metric_sum = 0.0step = 1for step, (features, labels) in enumerate(dl_train, 1):# train模块,也可以直接放在这里loss, metric = train(model, features, labels)# 打印batch级别日志loss_sum += lossmetric_sum += metric# 设置打印freqif step % log_step_freq == 0:print(("[step = %d] loss: %.3f, " + metric_name + ": %.3f") %(step, loss_sum / step, metric_sum / step))# 2,验证循环-------------------------------------------------val_loss_sum = 0.0val_metric_sum = 0.0val_step = 1for val_step, (features, labels) in enumerate(dl_valid, 1):# valid模块val_loss, val_metric = valid(model, features, labels)val_loss_sum += val_lossval_metric_sum += val_metric# 3,记录日志-------------------------------------------------info = (epoch, loss_sum / step, metric_sum / step,val_loss_sum / val_step, val_metric_sum / val_step)dfhistory.loc[epoch - 1] = info# 打印epoch级别日志print(("\nEPOCH = %d, loss = %.3f," + metric_name + \"  = %.3f, val_loss = %.3f, " + "val_" + metric_name + " = %.3f")% info)nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print("\n" + "==========" * 8 + "%s" % nowtime)print('Finished Training...')return dfhistory

训练实例:

epochs = 25
dfhistory = train_model(model,epochs,dl_train,dl_valid,50)
Start Training...
================================================================================2022-01-19 10:39:33
[step = 50] loss: 0.662, auc: 0.699
[step = 100] loss: 0.627, auc: 0.747
[step = 150] loss: 0.605, auc: 0.762
[step = 200] loss: 0.593, auc: 0.770EPOCH = 1, loss = 0.593,auc  = 0.770, val_loss = 0.514, val_auc = 0.839================================================================================2022-01-19 10:39:43
[step = 50] loss: 0.541, auc: 0.805
[step = 100] loss: 0.539, auc: 0.806
[step = 150] loss: 0.531, auc: 0.813
[step = 200] loss: 0.524, auc: 0.819......

4、评估模型

直接print(dfhistory)即可;

def plot_metric(dfhistory,metric,name):""":param dfhistory: 训练的info:param metric: 指定训练的哪个指标:return: 返回对应的训练曲线"""train_metrics = dfhistory[metric]val_metrics = dfhistory['val_'+metric]epochs = range(1,len(train_metrics)+1)plt.plot(epochs,train_metrics,"bo--")plt.plot(epochs,val_metrics,"ro-")plt.title("Training and validation "+metric)plt.xlabel("Epochs")plt.ylabel(metric)plt.legend(["train_"+metric, 'val_'+metric])# saveplt.savefig("figure/"+name+".jpg")plt.show()

这里将plot_metric放在utils.py中;

from utils import plot_metric
plot_metric(dfhistory,"loss",name="image_train_loss")
plot_metric(dfhistory,"auc",name="image_train_auc")

image_train_auc.jpg

image_train_loss.jpg

5、使用模型进行预测

def predict(model,dl):model.eval()result = torch.cat([model.forward(t[0]) for t in dl])return (result.data)
# 预测概率
y_pred_probs = predict(model,dl_valid)
print("y_pred_probs:",y_pred_probs)
# 预测类别
y_pred = torch.where(y_pred_probs>0.5,torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))
print(y_pred)

6、保存模型

浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式https://www.jb51.net/article/187269.htm采用torch.save保存模型参数:

https://pytorch.org/docs/stable/generated/torch.save.html?highlight=save#torch.savehttps://pytorch.org/docs/stable/generated/torch.save.html?highlight=save#torch.save

torch.save(model.state_dict(),"model/model_parameter_image.pkl")
net_clone = Image_Net()
net_clone.load_state_dict(torch.load("model/model_parameter_image.pkl"))
# test
predict(net_clone,dl_valid)

后续会增加onnx模型部署!

Pytorch快速搭建并训练CNN模型?相关推荐

  1. Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图

    Pytorch快速搭建Alexnet实现手写英文字母识别+PyQt实现鼠标绘图 前言 一.案例要求 二.训练数据准备 1.下载手写英文字母数据集 2.构建自己的数据集 三.AlexNet实现 1.Al ...

  2. GitChat · 人工智能 | 如何零基础用 Keras 快速搭建实用深度学习模型

    GitChat 作者:谢梁 原文: 如何零基础用 Keras 快速搭建实用深度学习模型 关注微信公众号:GitChat 技术杂谈 ,一本正经的讲技术 [不要错过文末活动] 前言 在这篇小文章中,我们将 ...

  3. 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体

    感想 首先我是首先看了一下莫凡pyhton教程中tensorflow python搭建自己的神经网络教程以及查看了官方的教程TensorFlow中文社区-MNIST进阶教程,这里面只是有简单的测试出来 ...

  4. pytorch下搭建网络训练并保存模型

    最近在学习pytorch,使用mnist数据集,搭建AlexNet训练并保存模型,将代码做一记录. 建立数据集的方法见pytorch建立自己的数据集(以mnist为例) 搭建网络的方法见用pytorc ...

  5. 使用Pytorch快速搭建神经网络模型(附详细注释和讲解)

    文章目录 0 前言 1 数据读入 2 模型搭建 3 模型训练 4 模型测试 5 模型保存 6 参考博客 0 前言 代码参考了知乎上"10分钟快速入门PyTorch"系列,并且附上了 ...

  6. tflearn教程_利用 TFLearn 快速搭建经典深度学习模型

    使用 TensorFlow 一个最大的好处是可以用各种运算符(Ops)灵活构建计算图,同时可以支持自定义运算符(见本公众号早期文章<TensorFlow 增加自定义运算符>).由于运算符的 ...

  7. 使用Tensorflow搭建并训练TextCNN模型,对文本进行分类

    最近有学习关于文本分类的深度学习模型,最先接触的就是TextCNN模型,该模型看起来非常简单效果也非常好,在此简单记录下整个模型的搭建以及训练过程.通过本博文,你可以自己搭建并训练一个简单的文本分类模 ...

  8. pytorch实现resnet50(训练+测试+模型转换)

    本章使用pytorch训练resnet50,使用cifar数据集. 数据集: 代码工程: 1.train.py import torch from torch import nn, optim imp ...

  9. 关于auto-keras训练cnn模型

    # 我在训练自己的人脸分类模型的时候发现图片的维度不能太高,经过很多次测试过后觉得一般人脸图片分为28*28大小训练的效果比较好.建议在使用其训练自己的物体识别模型的时候,尽量把图片压缩到28*28# ...

最新文章

  1. 半波对称振子方向图_移动天线的概念 | 天线方向性
  2. javascript删除数组,索引出现问题解决办法。
  3. Struts2学习---基本配置,action,动态方法调用,action接收参数
  4. 如何证明CPU的乱序执行(Out-of-order Execution)?
  5. Python MD5
  6. 【摘抄】其实我是间谍!
  7. R︱Rstudio 1.0版本尝鲜(R notebook、下载链接、sparkR、代码时间测试profile)
  8. IDEA Jsp乱码大全
  9. C++对数计算log
  10. dlib实现人脸对齐方法
  11. 高级JAVA面试题详解(三)——Redis(redis cluster、虚拟槽、一致性hash算法、master选举、淘汰策略、String数据结构)
  12. CST启用GPU加速的调试笔记
  13. Photoshop抠图(色彩范围命令扣人物/动物毛发图)
  14. 在html语言中的换行标记是指,南开20春学期(1709、1803、1809、1903、1909、2003)《电子商务网页制作》在线作业题目【标准答案】...
  15. Iphone手机被偷了 我是如何自保和尝试找回的
  16. FileReader读取文件
  17. 2021年8月互联网舆情热点事件报告
  18. 苹果新系统耗电过快怎么解决(解决方法)
  19. 用Python写了一个贪吃蛇大冒险小游戏
  20. 同事背后说坏话怎么办?为人再老实,也要做这3件事,吃亏不是福

热门文章

  1. C#找到最小的整数X,同时满足:X是2019的整倍数,X的每一位数字是奇数
  2. 论文写作-Latex问题和工具
  3. 国家公务员面试主要采取的是结构化的面试形式
  4. 需要免费虚拟机的朋友看过来
  5. 端口映射指导 ----- 配置文件方式
  6. The Wiley Handbook of Human Computer Interaction翻译
  7. opecv 证件照处理
  8. 面试中的操作系统知识
  9. 大功率半导体可调谐激光器
  10. 介绍一款好用的java反编译工具 - jd-gui