文章目录

  • class_indices.json
  • model.py
  • predict.py
  • train.py
  • 创建自己的数据集

#详解

class_indices.json

{"0": "daisy","1": "dandelion","2": "roses","3": "sunflowers","4": "tulips"
}

model.py

import torch.nn as nn
import torchclass AlexNet(nn.Module):# 继承model类def __init__(self, num_classes=1000, init_weights=False):#初始化参数,定义参数与层结构super(AlexNet, self).__init__()self.features = nn.Sequential(#Sequential能把一系列的层结构打包成一个新的层结构,当前层结构被定义为提取特征的层结构nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), #第一层 # input[3, 224, 224]  output[48, 55, 55],他只用了一半的卷积核(padding=(1,2),计算后是小数,就又一样了)nn.ReLU(inplace=True),#inplace是pytorch通过一种操作增加计算量减少内存占用nn.MaxPool2d(kernel_size=3, stride=2),#卷积核大小是3,步距是2                  # output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(#分类器nn.Dropout(p=0.5),#dropout的方法上全连接层随机失活(一般放在全裂阶层之间)p值随即失火的比例nn.Linear(128 * 6 * 6, 2048),#linear是全连接层nn.ReLU(inplace=True),#激活函数nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),#输出是数据集的类别个数)if init_weights:#初始化权重,定义在下面self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)#展平x = self.classifier(x)return xdef _initialize_weights(self):#其实不用,目前pytorch自动就是这个for m in self.modules():#会返回一个迭代器,遍历模型中所有的模块(遍历每一个层结构)if isinstance(m, nn.Conv2d):#是否是卷积nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')#是就去kaiming_normal初始化if m.bias is not None:#偏置不是0就置0nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):#如果是全连接层nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():# predict classoutput = torch.squeeze(model(img))#压缩掉batch的维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#指定训练过程中使用的设备
print(device)data_transform = {#数据预处理函数"train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪到224*224像素大小transforms.RandomHorizontalFlip(),#水平随即反转transforms.ToTensor(),#转化成tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path#获取数据集的根目录(os.getcwd():获取当前稳健所在目录;os.path.join合并到一起)
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",#加载数据集,train下面每一类是一个文件夹transform=data_transform["train"])#transform是数据预处理(之前定义的),map
train_num = len(train_dataset)#数据集有多少张图片# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx####################!!!!!获取类的名称对应的索引
cla_dict = dict((val, key) for key, val in flower_list.items())#江键值反转
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)#编码成json的格式
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=True,num_workers=0)#查看数据集的代码
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()
#
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)#实例化,分类集有五类,初始化权重是truenet.to(device)#分设备
#损失函数与优化器
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)save_path = './AlexNet.pth'
best_acc = 0.0#用来保存最佳平均准确率,为了保存效果最好的一次模型
for epoch in range(10):#10轮# trainnet.train()#用net.train()与net.eavl() 因为用了dropout,希望只在训练时失活,所以用这个来管理dropoutrunning_loss = 0.0t1 = time.perf_counter()#统计训练一个epoch所使用的时间for step, data in enumerate(train_loader, start=0):#遍历数据集images, labels = data#将数据分成图像与对应的标签optimizer.zero_grad()#清空梯度信息outputs = net(images.to(device))#正向传播,将图像也指认到设备上loss = loss_function(outputs, labels.to(device))#计算损失loss.backward()#反向传播optimizer.step()#更新参数# print statisticsrunning_loss += loss.item()#将loss的值累加到runningloss中(loss。item才是loss的值)# print train process,打印训练进度rate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter()-t1)# validate验证net.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():#禁止损失跟踪for val_data in validate_loader:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()#计算准确个数val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)#保存权重print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

创建自己的数据集

偷懒的办法
在flowerdata下直接全删了改自己的

3.2 使用pytorch搭建AlexNet并训练花分类数据集相关推荐

  1. 使用pytorch搭建AlexNet并训练花分类数据集

    深度学习学习笔记 导师博客:https://blog.csdn.net/qq_37541097/article/details/103482003 导师github:https://github.co ...

  2. 深度学习网络模型——RepVGG网络详解、RepVGG网络训练花分类数据集整体项目实现

    深度学习网络模型--RepVGG网络详解.RepVGG网络训练花分类数据集整体项目实现 0 前言 1 RepVGG Block详解 2 结构重参数化 2.1 融合Conv2d和BN 2.2 Conv2 ...

  3. pytorch——AlexNet——训练花分类数据集

    宝藏博主:霹雳吧啦Wz_太阳花的小绿豆_CSDN博客-深度学习,Tensorflow,软件安装领域博主 目录 数据集下载 训练集与测试集划分 "split_data.py" Ale ...

  4. 使用pytorch搭建AlexNet网络模型

    使用pytorch搭建AlexNet网络模型 AlexNet详解 AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Ch ...

  5. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  6. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

  7. 使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负

    使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负 1. 数据集 百度网盘链接,提取码:q79p 数据集文件格式为CSV.数据集包含了大约5万场英雄联盟钻石排位赛前15分钟的数据集合,总 ...

  8. AlexNet网络的搭建以及训练花分类

    前言 本学习笔记参考自B站up主霹雳吧啦Wz 代码均来自导师github开源项目WZMIAOMIAO/deep-learning-for-image-processing: deep learning ...

  9. pytorch学习2:pytorch搭建Alexnet网络

    推荐神仙up主 霹雳吧啦Wz 我的代码基本就是按照他的代码自己写了一遍加深印象,有兴趣的可以去看看,强烈推荐.我写这个博客只是记录一下学习的过程,防止忘记.添加了一些注释,帮助理解. 1.模型 imp ...

最新文章

  1. Java查找数组重复元素,并打印重复元素、重复次数、重复元素位置
  2. MySQL LEFT/RIGHT JOIN:外连接查询
  3. 数据查询语言(DQL)
  4. 阿里云MySQL按流量计费吗_阿里云服务器按使用流量计费带宽峰值1M和100M费用方面有区别吗?...
  5. XML和HTML有什么区别?两者之间有什么关联?
  6. bat set命令详解
  7. 【NOIP考前模拟赛】纯数学方法推导——旅行者问题
  8. 送给程序员:IT大神们的编程名言
  9. Gamma阶段第八次scrum meeting
  10. java设计连连看心得_基于Java的连连看游戏的设计与实现
  11. 360创始人周鸿祎曾这样告诫年轻人
  12. JavaScript匿名函数和回调函数
  13. 【视频】超级账本HyperLedger:Fabric源码走读(一):项目构建与代码结构
  14. OpenCV环境搭建(Windows+Visual studio)及Hello World
  15. 信号处理-基于希尔伯特解调(包络谱)的轴承故障诊断实战,通过python代码实现超详细讲解
  16. 操作系统概念第八章部分作业题答案
  17. androidsettitle方法_在Android应用程序中,Toolbar.setTitle方法无效 - 应用程序名称显示为ti...
  18. 100000+人体验过后都说:这TM绝对是最变态的英语学习方法……
  19. 板块分析:筑底阶段 智能家居开启蓝海
  20. 如何分发大文件、大文件断点续传解决方案

热门文章

  1. 从头编写 asp.net core 2.0 web api 基础框架 (4) EF配置
  2. AppWidget应用(一)---创建一个appWidget
  3. Angular JS 中的内置方法之表单验证
  4. EF延迟加载LazyLoading
  5. centos6.5 scala环境变量
  6. 计算机听不到音乐怎么回事,Win10电脑设置麦克风提示“计算机听不到任何声音”如何解决...
  7. python的结构_Python结构的选择,python,之
  8. 配置一个Servlet可以被一个(指定的开头链接,后自定义)访问
  9. java 弱引用定位_手把手教你定位常见Java性能问题
  10. 华为手表用鸿蒙了吗,华为鸿蒙都2.0了,手机还不能用吗?