Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)

  • 介绍
    • AlexNet网络模型
    • 猫狗数据集
    • AlexNet网络训练
    • 训练全代码
    • 预测
    • 预测图片

介绍

AlexNet模型是CNN网络中经典的网络模型,适合初学者学习,本文对AlexNet结构参数初步说明,详细可以下载论文。通过AlexNet对Kaggle的猫狗数据集进行训练和预测,相关资料为搜集总结。

AlexNet网络模型


如图是2012年AlexNet网络模型结构,由于之前GPU内存小,当时网络是采用了两块GPU,现在训练是不需要的,AlexNet的特点归纳为以下几点:

  1. 网络结构 ,包括五个卷积层,三个池化层,三个全连接层;
  2. 激活函数,采用Relu函数,优点是:网络训练更快;防止梯度消失;使得网络更具稀疏性;
  3. 池化层,采用最大值池化Max pooling,一个像素表示一块区域的像素值,降低图像分辨率;池化层虽然没有可以学习参数,也相当于另类卷积层;
  4. 训练技巧,(1)采用数据增强,对训练阶段和测试阶段的数据都进行了增强,包括针对位置的裁剪和针对颜色的扰动。在其结论中,颜色扰动效果有限,尺度扰动有不错的效果;(2)采用了随机失活Dropout,使得权值按P的概率失活,达到防止过拟合的目的。

猫狗数据集

Kaggle猫狗数据集,可以直接在官网下载https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data。下载的压缩包解压后,如图:

train是25000张猫狗图片,各占一半,图片名字都进行了标记,由于在同一个文件中,在数据处理阶段,需要打乱图片顺序以及读取图片名对其进行标记分类。训练阶段,只需要用到train文件的图片,首先写个dataset方便图片读取和相关操作,文件命名为My_dataset.py.

import os
import random
from PIL import Image
from torch.utils.data import Datasetrandom.seed(1)
class CatDogDataset(Dataset):def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):"""rmb面额分类任务的Dataset:param data_dir: str, 数据集所在路径:param transform: torch.transform,数据预处理"""self.mode = modeself.data_dir = data_dirself.rng_seed = rng_seedself.split_n = split_nself.data_info = self._get_img_info()  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本self.transform = transformdef __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等return img, labeldef __len__(self):if len(self.data_info) == 0:raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))return len(self.data_info)def _get_img_info(self):img_names = os.listdir(self.data_dir)img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))random.seed(self.rng_seed)random.shuffle(img_names)img_labels = [0 if n.startswith('cat') else 1 for n in img_names]split_idx = int(len(img_labels) * self.split_n)  # 25000* 0.9 = 22500# split_idx = int(100 * self.split_n)if self.mode == "train":img_set = img_names[:split_idx]     # 数据集90%训练# img_set = img_names[:22500]     #  hard code 数据集90%训练label_set = img_labels[:split_idx]elif self.mode == "valid":img_set = img_names[split_idx:]label_set = img_labels[split_idx:]else:raise Exception("self.mode 无法识别,仅支持(train, valid)")path_img_set = [os.path.join(self.data_dir, n) for n in img_set]data_info = [(n, l) for n, l in zip(path_img_set, label_set)]return data_info

AlexNet网络训练

在开始网络搭建前,还需的准备工作:
为了提高模型分类准确度,引入AlexNet在ImageNet比赛时的预训练模型,AlexNet结构就不重写了,直接调用Pytorch中的预设模型,AlexNet最后全连接层是1000分类的,所以之后代码中还需要修改最后一层参数。
链接:https://pan.baidu.com/s/16xd6PjmjPrAKIbta81yUdw
提取码:cyh9
简单的函数对预训练模型读取。

def get_model(path_state_dict, vis_model=False):"""创建模型,加载参数:param path_state_dict::return:"""model = models.alexnet()pretrained_state_dict = torch.load(path_state_dict)model.load_state_dict(pretrained_state_dict)if vis_model:from torchsummary import summarysummary(model, input_size=(3, 224, 224), device="cpu")model.to(device)return model

训练全代码

import os
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
import torchvision.models as models
from A_alexnet.tools.my_dataset import CatDogDatasetBASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def get_model(path_state_dict, vis_model=False):"""创建模型,加载参数:param path_state_dict::return:"""model = models.alexnet()pretrained_state_dict = torch.load(path_state_dict)model.load_state_dict(pretrained_state_dict)if vis_model:from torchsummary import summarysummary(model, input_size=(3, 224, 224), device="cpu")model.to(device)return modelif __name__ == "__main__":# configdata_dir = os.path.join(BASE_DIR, "..", "data", "train")# 读取预训练模型path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")# 二分类,设置类为2num_classes = 2MAX_EPOCH = 3       # 可自行修改,设置大效果会好点BATCH_SIZE = 200    # 可自行修改,内存大可以设置大点,速度快点LR = 0.001          # 可自行修改log_interval = 1    # 可自行修改val_interval = 1    # 可自行修改classes = 2start_epoch = -1lr_decay_step = 1   # 可自行修改# ============================ step 1/5 数据 ============================norm_mean = [0.485, 0.456, 0.406]norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((256)),      # (256, 256) 区别transforms.CenterCrop(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),])normalizes = transforms.Normalize(norm_mean, norm_std)valid_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.TenCrop(224, vertical_flip=False),transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),])# 构建MyDataset实例train_data = CatDogDataset(data_dir=data_dir, mode="train", transform=train_transform)valid_data = CatDogDataset(data_dir=data_dir, mode="valid", transform=valid_transform)# 构建DataLodertrain_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)valid_loader = DataLoader(dataset=valid_data, batch_size=4)# ============================ step 2/5 模型 ============================alexnet_model = get_model(path_state_dict, False)num_ftrs = alexnet_model.classifier._modules["6"].in_featuresalexnet_model.classifier._modules["6"] = nn.Linear(num_ftrs, num_classes)alexnet_model.to(device)# ============================ step 3/5 损失函数 ============================criterion = nn.CrossEntropyLoss()# ============================ step 4/5 优化器 ============================optimizer = optim.SGD(alexnet_model.parameters(), lr=LR, momentum=0.9)  # 选择优化器scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # 设置学习率下降策略# ============================ step 5/5 训练 ============================train_curve = list()valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.alexnet_model.train()for i, data in enumerate(train_loader):# if i > 1:#     break# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = alexnet_model(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.alexnet_model.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)bs, ncrops, c, h, w = inputs.size()     # [4, 10, 3, 224, 224outputs = alexnet_model(inputs.view(-1, c, h, w))outputs_avg = outputs.view(bs, ncrops, -1).mean(1)loss = criterion(outputs_avg, labels)_, predicted = torch.max(outputs_avg.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val/len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))alexnet_model.train()train_x = range(len(train_curve))train_y = train_curvetrain_iters = len(train_loader)valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterationsvalid_y = valid_curve# 保存网络模型及参数torch.save(alexnet_model.state_dict(), 'whole_CatDog_params.pth')plt.plot(train_x, train_y, label='Train')plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')plt.ylabel('loss value')plt.xlabel('Iteration')plt.show()

运行训练代码,就开始计算。

一般设置3个epoch,训练集和验证集准确率都在96%以上。

注意:在代码中,torch.save(alexnet_model.state_dict(), ‘whole_CatDog_params.pth’)已经保存了最终训练模型及参数,路径就自己设置,这里我直接保存。

预测

import os
os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'
import time
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def img_transform(img_rgb, transform=None):"""将数据转换为模型读取的形式:param img_rgb: PIL Image:param transform: torchvision.transform:return: tensor"""if transform is None:raise ValueError("找不到transform!必须有transform对img进行处理")img_t = transform(img_rgb)return img_tdef load_class_names(p_clsnames, p_clsnames_cn):"""加载标签名:param p_clsnames::param p_clsnames_cn::return:"""with open(p_clsnames, "r") as f:class_names = json.load(f)with open(p_clsnames_cn, encoding='UTF-8') as f:  # 设置文件对象class_names_cn = f.readlines()return class_names, class_names_cndef get_model(path_state_dict, num_classes, vis_model=False):"""创建模型,加载参数:param path_state_dict::return:"""model = models.alexnet(num_classes=num_classes)pretrained_state_dict = torch.load(path_state_dict)model.load_state_dict(pretrained_state_dict)model.eval()if vis_model:from torchsummary import summarysummary(model, input_size=(3, 224, 224), device="cpu")model.to(device)return modeldef process_img(path_img):# hard codenorm_mean = [0.485, 0.456, 0.406]norm_std = [0.229, 0.224, 0.225]inference_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop((224, 224)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),])# path --> imgimg_rgb = Image.open(path_img).convert('RGB')# img --> tensorimg_tensor = img_transform(img_rgb, inference_transform)img_tensor.unsqueeze_(0)        # chw --> bchwimg_tensor = img_tensor.to(device)return img_tensor, img_rgbif __name__ == "__main__":num_classes=2# configpath_state_dict = os.path.join(BASE_DIR, "whole_CatDog_params_0909.pth")path_img = os.path.join(BASE_DIR, "..", "data", "272.jpg")# 1/5 load imgimg_tensor, img_rgb = process_img(path_img)# 2/5 load modelalexnet_model = get_model(path_state_dict,num_classes, True)with torch.no_grad():time_tic = time.time()outputs = alexnet_model(img_tensor)time_toc = time.time()# 4/5 index to class names_, pred_int = torch.max(outputs.data, 1)_, top1_idx = torch.topk(outputs.data, 1, dim=1)#pred_idx = int(pred_int.cpu().numpy())if pred_idx == 0:pred_str= str("cat")print("img: {} is: {}".format(os.path.basename(path_img), pred_str))else:pred_str = str("dog")print("img: {} is: {}".format(os.path.basename(path_img), pred_str))print("time consuming:{:.2f}s".format(time_toc - time_tic))# 5/5 visualizationplt.imshow(img_rgb)plt.title("predict:{}".format(pred_str))plt.text(5, 45, "top {}:{}".format(1, pred_str), bbox=dict(fc='yellow'))plt.show()

预测模型有需要注意的地方,是get_model直接修改最后全连接层的分类数为2,跟训练修改方法不一致。用训练模型的修改方法不知道为啥一直有bug,水平有限就直接换了方法修改。预测模型没有全导入test的照片预测,只是截取一张预测一个结果,感兴趣可以自行导入全部test照片。

推荐使用torchsummary,结果可以很直观看到网络结构及相关参数。直接pip install torchsummary就可以安装了。

预测图片


仅作为学习总结分享,有错误望小伙伴们指正。

Pytorch采用AlexNet实现猫狗数据集分类(训练与预测)相关推荐

  1. 基于VGG深度学习神经网络的猫狗数据集分类

    摘要:VGG网络是由牛津大学视觉几何组完成的基于深度卷积神经网络的大规模图像识别架构,该网络参考了AlexNet.ZFNet.OverFeat等经典的网络架构,从而得出的.这个架构参加了ILSVRC- ...

  2. 用Tensorflow实现AlexNet识别猫狗数据集(猫狗大战)【附代码】

    AlexNet识别猫狗数据集 一.下载猫狗数据集 二.AlexNet实现 1.划分训练集和测试集 2.将训练集和测试集图片放缩为224x224 3.AlexNet实现 4.训练过程 5.模型测试 三. ...

  3. 使用PYTORCH复现ALEXNET实现猫狗识别

    完整代码链接:https://github.com/SPECTRELWF/pytorch-cnn-study 网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在 ...

  4. python学习之猫狗数据集分类实验(二)

    前面一个博客,介绍了训练之类的详情:https://blog.csdn.net/qq_43433255/article/details/93855517 下面,就用得到.h5文件继续. 这里是所有的头 ...

  5. 基于VGGnet识别猫狗数据集(猫狗大战)【附代码】

    文章目录 一.下载kaggle猫狗大战数据集 二.VGGnet实现 1.划分数据集 2.将训练集和测试集图片放缩为224x224 2.实现VGGnet 3.测试模型 三.总结 四.致歉 一.下载kag ...

  6. AlexNet 实现猫狗分类(keras and pytorch)

    AlexNet 实现猫狗分类 前言 在训练网络过程中遇到了很多问题,先在这里抱怨一下,没有硬件条件去使用庞大的ImageNet2012 数据集 .所以在选择合适的数据集上走了些弯路,最后选择有kagg ...

  7. (!详解 Pytorch实战:①)kaggle猫狗数据集二分类:加载(集成/自定义)数据集

    这系列的文章是我对Pytorch入门之后的一个总结,特别是对数据集生成加载这一块加强学习 另外,这里有一些比较常用的数据集,大家可以进行下载: 需要注意的是,本篇文章使用的PyTorch的版本是v0. ...

  8. PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类

    目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...

  9. 神经网络学习小记录17——使用AlexNet分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录17--使用AlexNet分类模型训练自己的数据(猫狗数据集) 学习前言 什么是AlexNet模型 训练前准备 1.数据集处理 2.创建Keras的AlexNet模型 开始训练 1. ...

最新文章

  1. Nature展示迄今为止最详细的“人脑零部件清单”
  2. Flutter开发之数据存储-2-文件存储(33)
  3. undertow ssl_SSL与WildFly 8和Undertow
  4. CSS3: 动画循环执行(带延迟)的实现
  5. python django开发工具_Python和Django web开发工具pycharm介绍
  6. Maven + Docker
  7. .NET中删除确认框的实现
  8. ros之旋转加平移公式
  9. loadrunner:关联操作
  10. HDOJ 1394 Minimum Inversion Number
  11. ueditor php提交表单,ThinkPHP使用Ueditor的方法详解
  12. 《量子信息与量子计算简明教程》第一章·基本概念(下)
  13. 【历史上的今天】8 月 24 日:Windows 95 问世;乔布斯辞任苹果 CEO;库克上台
  14. 818释放规模效能,苏宁易购全场景智慧零售迈上新台阶
  15. Hadoop 3.X, 纠删码
  16. IDEA启动Tomcat 中文乱码问题
  17. C++遇到Id returned 1 exit status解决办法
  18. 电子身份证助力打击钓鱼攻击
  19. 手机视频投屏到电视或投影仪
  20. STM32F4中断优先级NVIC管理

热门文章

  1. 去除桌面快捷键小箭头
  2. 基于Lumerical Mode的典型波导腔面本征模式的光场图计算
  3. 看完做不好百度爱采购你打我!
  4. trunc函数 mysql_oracle trunc函数使用详解
  5. 湖北立捷科技:淘宝商品发布规则介绍
  6. RimWorld模组教程之物品
  7. 第十八章----面向对象(宠物乱斗之子类篇)
  8. 服务器16g内存装哪个系统,16g内存的电脑装什么系统好
  9. Kibana常用命令
  10. VR元宇宙技术在各领域的应用场景