宝藏博主:霹雳吧啦Wz_太阳花的小绿豆_CSDN博客-深度学习,Tensorflow,软件安装领域博主

目录

数据集下载

训练集与测试集划分

“split_data.py”

Alexnet讲解:

名称解读

1)过拟合:

2) Dropout:

3)gpu

1. model.py

2. train.py

先对训练集的预处理

1)transforms.ToTensor(),

2) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

3)transforms.Compose()类详解:串联多个transform操作

导入、加载 训练集

导入、加载 验证集

存储 索引:标签 的字典

训练过程

3. predict.py


数据集下载

http://download.tensorflow.org/example_images/flower_photos.tgz
包含 5 中类型(雏菊,蒲公英,玫瑰,向日葵,郁金香)的花,每种类型有600~900张图像

训练集与测试集划分

因为这次数据集不在dataset中,因此需要自己划分

参考:deep-learning-for-image-processing/README.md at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub

1.先将数据集压缩包解压到data_set文件夹中的flower_data中

2.在data_set目录下执行 shift + 右键 打开 PowerShell ,

3.执行 “split_data.py” 分类脚本自动将数据集划分成 训练集train 和 验证集val

“split_data.py”

import os
from shutil import copy
import randomdef mkfile(file):if not os.path.exists(file):os.makedirs(file)# 获取 flower_photos 文件夹下除 .txt 文件以外所有文件夹名(即5种花的类名)
file_path = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla] # 创建 训练集train 文件夹,并由5种类名在其目录下创建5个子目录
mkfile('flower_data/train')
for cla in flower_class:mkfile('flower_data/train/'+cla)# 创建 验证集val 文件夹,并由5种类名在其目录下创建5个子目录
mkfile('flower_data/val')
for cla in flower_class:mkfile('flower_data/val/'+cla)# 划分比例,训练集 : 验证集 = 9 : 1
split_rate = 0.1# 遍历5种花的全部图像并按比例分成训练集和验证集
for cla in flower_class:cla_path = file_path + '/' + cla + '/'  # 某一类别花的子目录images = os.listdir(cla_path)           # iamges 列表存储了该目录下所有图像的名称num = len(images)eval_index = random.sample(images, k=int(num*split_rate)) # 从images列表中随机抽取 k 个图像名称for index, image in enumerate(images):# eval_index 中保存验证集val的图像名称if image in eval_index:                 image_path = cla_path + imagenew_path = 'flower_data/val/' + clacopy(image_path, new_path)  # 将选中的图像复制到新路径# 其余的图像保存在训练集train中else:image_path = cla_path + imagenew_path = 'flower_data/train/' + clacopy(image_path, new_path)print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing barprint()print("processing done!")

Alexnet讲解:

名称解读

1)过拟合:

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价

我的理解是因为太贴合训练集结果,导致我们的程序过度解读特征,我们的训练后的模型不能很好的预测其他的数据,而几乎完美贴合测试集

2) Dropout:

Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作(意思就是随机失活下一层的神经元)

3)gpu

gpu加速就不用说了,打游戏深有体会

具体图

1. model.py

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return x# 网络权重初始化,实际上 pytorch 在构建网络时会自动初始化权重def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

2. train.py

先对训练集的预处理

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),       # 随机裁剪,再缩放成 224×224transforms.RandomHorizontalFlip(p=0.5),  # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转transforms.ToTensor(),      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

1)transforms.ToTensor(),

ToTesnor会数据归一化到均值为0,方差为1(是将数据除以255)

2) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

本来就将数据缩小了,那为什么ToTensor后面加Normalize?

我找到很好的博文分享给大家

Normalize()是对数据按通道进行标准化,即减去均值,再除以方差

数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度
因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。
————————————————
版权声明:本文为CSDN博主「小研一枚」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35027690/article/details/103742697

3)transforms.Compose()类详解:串联多个transform操作

参考transforms.Compose()类详解:串联多个transform操作_migue_math-CSDN博客_transforms.compose

导入、加载 训练集

但是这次的 花分类数据集 并不在 pytorch 的 torchvision.datasets. 中,因此需要用到datasets.ImageFolder()    来导入。

ImageFolder()返回的对象是一个包含数据集所有图像及对应标签构成的二维元组容器,支持索引和迭代,可作为torch.utils.data.DataLoader的输入

参考Pytorch 加载图像数据(ImageFolder和Dataloader)_陶将的博客-CSDN博客

# 获取图像数据集的路径
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))          # get data root path 返回上上层目录
image_path = data_root + "/data_set/flower_data/"                       # flower data_set path# 导入训练集并进行预处理
train_dataset = datasets.ImageFolder(root=image_path + "/train",       transform=data_transform["train"])
train_num = len(train_dataset)# 按batch_size分批次加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset,  # 导入的训练集batch_size=32,     # 每批训练的样本数shuffle=True,    # 是否打乱训练集num_workers=0)    # 使用线程数,在windows下设置为0

导入、加载 验证集

# 导入验证集并进行预处理
validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)# 加载验证集
validate_loader = torch.utils.data.DataLoader(validate_dataset,    # 导入的验证集batch_size=32, shuffle=True,num_workers=0)

存储 索引:标签 的字典

这和爬虫爬取数据操作很像

# 字典,类别:索引 {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# 将 flower_list 中的 key 和 val 调换位置
cla_dict = dict((val, key) for key, val in flower_list.items())# 将 cla_dict 写入 json 文件中
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)

训练过程

  • net.train():训练过程中开启 Dropout
  • net.eval(): 验证过程关闭 Dropout
    net = AlexNet(num_classes=5, init_weights=True)      # 实例化网络(输出类型为5,初始化权重)
    net.to(device)                                        # 分配网络到指定的设备(GPU/CPU)训练
    loss_function = nn.CrossEntropyLoss()                # 交叉熵损失
    optimizer = optim.Adam(net.parameters(), lr=0.0002)     # 优化器(训练参数,学习率)save_path = './AlexNet.pth'
    best_acc = 0.0for epoch in range(10):########################################## train ###############################################net.train()                       # 训练过程中开启 Dropoutrunning_loss = 0.0                    # 每个 epoch 都会对 running_loss  清零time_start = time.perf_counter()    # 对训练一个 epoch 计时for step, data in enumerate(train_loader, start=0):  # 遍历训练集,step从0开始计算images, labels = data   # 获取训练集的图像和标签optimizer.zero_grad()  # 清除历史梯度outputs = net(images.to(device))                # 正向传播loss = loss_function(outputs, labels.to(device)) # 计算损失loss.backward()                                   # 反向传播optimizer.step()                              # 优化器更新参数running_loss += loss.item()# 打印训练进度(使训练过程可视化)rate = (step + 1) / len(train_loader)           # 当前进度 = 当前step / 训练一轮epoch所需总stepa = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print('%f s' % (time.perf_counter()-time_start))########################################### validate ###########################################net.eval()    # 验证过程中关闭 Dropoutacc = 0.0  with torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]  # 以output中值最大位置对应的索引(标签)作为预测输出acc += (predict_y == val_labels.to(device)).sum().item()    val_accurate = acc / val_num# 保存准确率最高的那次网络参数if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f \n' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')
    

  • 生成pth的模型文件

  • 3. predict.py

  • import torch
    from model import AlexNet
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import json# 预处理
    data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
    img = Image.open("./01.jpeg")
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)# read class_indict
    try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
    except Exception as e:print(e)exit(-1)# create model
    model = AlexNet(num_classes=5)
    # load model weights
    model_weight_path = "./AlexNet.pth"
    model.load_state_dict(torch.load(model_weight_path))# 关闭 Dropout
    model.eval()
    with torch.no_grad():# predict classoutput = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
    print(class_indict[str(predict_cla)], predict[predict_cla].item())
    plt.show()
    

pytorch——AlexNet——训练花分类数据集相关推荐

  1. 深度学习网络模型——RepVGG网络详解、RepVGG网络训练花分类数据集整体项目实现

    深度学习网络模型--RepVGG网络详解.RepVGG网络训练花分类数据集整体项目实现 0 前言 1 RepVGG Block详解 2 结构重参数化 2.1 融合Conv2d和BN 2.2 Conv2 ...

  2. 使用pytorch搭建AlexNet并训练花分类数据集

    深度学习学习笔记 导师博客:https://blog.csdn.net/qq_37541097/article/details/103482003 导师github:https://github.co ...

  3. 3.2 使用pytorch搭建AlexNet并训练花分类数据集

    文章目录 class_indices.json model.py predict.py train.py 创建自己的数据集 #详解 class_indices.json {"0": ...

  4. 使用AlexNet训练自己的数据集

    前言: 前两篇分别介绍两个图像识别的模型,第一个是mnist手写体图像识别,第二个是在第一个代码的基础上增加了一些优化技巧,例如正则化.dropout等,并且比较加上各种优化技巧之后图像识别的结果.  ...

  5. AlexNet网络的搭建以及训练花分类

    前言 本学习笔记参考自B站up主霹雳吧啦Wz 代码均来自导师github开源项目WZMIAOMIAO/deep-learning-for-image-processing: deep learning ...

  6. Pytorch搭建网络训练葡萄酒分类数据集(三分类)

    代码如下: import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F ...

  7. detectron2训练自己的数据集_YOLO(v3)PyTorch版 训练自己的数据集

    Yolo v3比Frcnn好调试多了--就是数据集准备比较麻烦-- 但是好Debug,linux和win10差别不大-- 代码链接(cpu版本): https://github.com/eriklin ...

  8. (Pytorch) YOLOV4 : 训练自己的数据集【左侧有码】

    项目地址:https://github.com/argusswift/YOLOv4-pytorch 这份代码实现的逻辑非常清楚,主要一些数据集处理的代码需要相应的改动: 这里的数据集label格式: ...

  9. Pytorch复现RepVGG模型,实现花分类

    写在开头,本篇博客是博主跟着一位b站up主学习过程中,发现up主只对网络讲解,但没有复现的视频. 在此放上up主的链接(点我) 为了加深印象,复现是不可少的,所有博主就去论文源码扒代码去了,但源码其实 ...

最新文章

  1. 超速电眼:全时成像芯片重塑机器视觉
  2. 02 | 日志系统:一条 SQL 更新语句是如何执行的
  3. boost::noinit_adaptor用法实例
  4. C# 三层级架构问题之 能加载文件或程序集或它的某一个依赖项。系统找不到指定的文件
  5. Hibernate查询方式
  6. c#之task与thread区别及其使用
  7. linux环境变量命名规范,Linux就该这么学 -- 重要的环境变量
  8. 关于Mytatis动态拼接in语句并且按照指定顺序排序的问题
  9. 【nexus】nexus : mac 安装 nexus
  10. 一手云端,一手终端:比特大陆发布两款AI芯片,大步迈进AI领域
  11. 7620a路由mysql_MT7620A路由刷DDWRT 及2.4G无线设置经验
  12. [Linux] Ubuntu Server 12.04 LTS 平台上搭建WordPress(Nginx+MySQL+PHP) Part IV
  13. 四元素与欧拉角之间的转换
  14. QPSK和16QAM调制
  15. Linux常见查看日志命令
  16. 【MOOC】华中科技大学操作系统慕课答案-单元作业+第1~2章开放性思考题
  17. 传感器自学笔记第五章——旋转编码器
  18. html - 移动标签 marquee 属性
  19. JIT准时生产制造管理
  20. 小米笔记本电池只充电到95%的设置

热门文章

  1. 使用Nexus搭建Maven私服(1)
  2. 新一代信息技术-大数据
  3. 【Python】一个房贷计算器功能的小案例
  4. vue2中component和components在组件注册和路由中的区别
  5. Android实现一键获取课程成绩dome
  6. 重新审视 Bancor 算法,为什么 cw 是失效的设计
  7. 生成一个cesium火焰特效
  8. 针对借“冠状病毒硬币”非法ICO的区块链疫情捐款诈骗分析 -星辰安全实验室
  9. LeetCode 银联-4. 设计自动售货机
  10. 基于 qemu 的 riscv32架构的 非官方rt-thread 体验 教程