Table of Contents

  • 1 Data
    • 1.1 Preparation
    • 1.2 Downloading the data
    • 1.3 Splitting the data
  • 2 Model
  • 3 Training
  • 4 Testing
  • 5 Scraping images and testing on them
    • 5.1 Scraping test images
    • 5.2 Testing on the scraped flowers
  • 6 Chinese zodiac (scrape, split, train/validate, test)

1 Data

1.1 Preparation

Create a folder named AlexNet, create a subfolder flower_data inside it, then extract the downloaded dataset into flower_data.

1.2 Downloading the data

Download TensorFlow's flower photos:
http://download.tensorflow.org/example_images/flower_photos.tgz
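If you prefer to fetch and unpack the archive from a script rather than a browser, a minimal sketch (the helper name and the `flower_data` destination are my choices to match the folder layout above, not from the original post):

```python
import os
import tarfile
import urllib.request

DATA_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"

def download_and_extract(url=DATA_URL, dest_dir="flower_data"):
    """Download the tarball (if not already present) and extract it under dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    archive = os.path.join(dest_dir, os.path.basename(url))  # flower_data/flower_photos.tgz
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest_dir)  # creates flower_data/flower_photos/<class folders>
    return archive
```

After this, flower_data/flower_photos contains one subfolder per flower class, which is the layout spile_data.py expects.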

1.3 Splitting the data

Open a terminal in the AlexNet folder:

gedit spile_data.py # paste in spile_data.py, then save and close
python spile_data.py # run spile_data.py

spile_data.py

import os
from shutil import copy
import random

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        # validation split
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        # training split
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # progress bar
    print()
print("processing done!")
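The partition logic above can be checked in isolation. A small sketch of the same `random.sample` split on dummy file names (the helper is hypothetical, and seeded only so the example is reproducible; spile_data.py itself is unseeded):

```python
import random

def split_filenames(images, split_rate=0.1, seed=0):
    """Partition file names into (train, val) the same way spile_data.py does."""
    rng = random.Random(seed)
    eval_index = rng.sample(images, k=int(len(images) * split_rate))
    val = [img for img in images if img in eval_index]
    train = [img for img in images if img not in eval_index]
    return train, val

images = ["img%03d.jpg" % i for i in range(50)]
train, val = split_filenames(images)
print(len(train), len(val))  # 45 5
```

With split_rate = 0.1, roughly 10% of each class lands in flower_data/val and the rest in flower_data/train.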

2 Model

Open a terminal in the AlexNet folder:

gedit model.py # paste in model.py, then save and close

model.py

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=5, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(                              # convolutional stack
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] (fractional sizes are floored)
            nn.ReLU(inplace=True),                                  # inplace=True saves memory, allowing larger models
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]  half the kernel count of the original paper
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(                            # fully connected head
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # flatten (view() would also work)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # Kaiming initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  # normal-distribution initialization
                nn.init.constant_(m.bias, 0)
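The layer-by-layer output sizes in the comments follow the standard convolution arithmetic, floor((W - K + 2P) / S) + 1, applied to each Conv2d and MaxPool2d in turn. A quick sanity check of the whole feature stack (pure Python, no torch needed):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output spatial size of a Conv2d/MaxPool2d layer (PyTorch floors the division)."""
    return (size - kernel + 2 * padding) // stride + 1

size = 224
size = conv_out(size, 11, stride=4, padding=2)  # Conv2d(3, 48)    -> 55
size = conv_out(size, 3, stride=2)              # MaxPool2d        -> 27
size = conv_out(size, 5, padding=2)             # Conv2d(48, 128)  -> 27
size = conv_out(size, 3, stride=2)              # MaxPool2d        -> 13
size = conv_out(size, 3, padding=1)             # Conv2d(128, 192) -> 13
size = conv_out(size, 3, padding=1)             # Conv2d(192, 192) -> 13
size = conv_out(size, 3, padding=1)             # Conv2d(192, 128) -> 13
size = conv_out(size, 3, stride=2)              # MaxPool2d        -> 6
print(size)  # 6
```

The final 6x6 map with 128 channels is exactly why the first Linear layer takes 128 * 6 * 6 = 4608 inputs.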

3 Training

Open a terminal in the AlexNet folder.

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# device: GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# data preprocessing
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),   # random crop to 224x224
                                 transforms.RandomHorizontalFlip(),   # random horizontal flip
                                 transforms.ToTensor(),               # convert to tensor
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # mean and std of 0.5
    "val": transforms.Compose([transforms.Resize((224, 224)),         # resize
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

batch_size = 32                           # batch size
data_root = os.getcwd()                   # current working directory
image_path = data_root + "/flower_data/"  # data path

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])  # load and preprocess the training set
train_num = len(train_dataset)  # training set size
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)  # training loader

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])  # validation set
val_num = len(validate_dataset)  # validation set size
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)  # validation loader

print("training set size: ", train_num, "\n")    # 3306
print("validation set size: ", val_num, "\n")    # 364

net = AlexNet(num_classes=5, init_weights=True)  # build the model
net.to(device)
loss_function = nn.CrossEntropyLoss()                 # loss: cross entropy
optimizer = optim.Adam(net.parameters(), lr=0.0002)   # optimizer: Adam
save_path = './AlexNet.pth'                           # where the trained weights are saved
best_acc = 0.0                                        # highest accuracy seen during training

# train one epoch, then validate one epoch
for epoch in range(10):
    # training
    print(">> start training: ", epoch+1)
    net.train()          # enable dropout for training
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()  # zero the gradients
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()  # backpropagation
        optimizer.step()
        running_loss += loss.item()  # accumulate the loss
        rate = (step + 1) / len(train_loader)  # training progress
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1)  # time taken by one epoch

    # validation
    print(">> start validation: ", epoch+1)
    net.eval()           # disable dropout for validation
    acc = 0.0            # number of correct classifications in this pass
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            #print("outputs: \n", outputs, "\n")
            predict_y = torch.max(outputs, dim=1)[1]
            #print("predict_y: \n", predict_y, "\n")
            acc += (predict_y == val_labels.to(device)).sum().item()  # count matches with the labels
        val_accurate = acc / val_num  # validation accuracy
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # keep the weights with the best accuracy
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflowers':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write the dict to a json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)
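Two details in the last few lines are worth seeing in isolation: `class_to_idx` maps name to index but prediction needs the reverse, and JSON serialization turns the integer keys into strings. A tiny sketch (class names hard-coded here for illustration):

```python
import json

# class_to_idx maps name -> index; prediction needs index -> name
flower_list = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
cla_dict = dict((val, key) for key, val in flower_list.items())
print(cla_dict[3])  # sunflowers

# json.dumps stringifies the integer keys, which is why predict.py
# later indexes with class_indict[str(predict_cla)]
json_str = json.dumps(cla_dict, indent=4)
restored = json.loads(json_str)
print(restored["3"])  # sunflowers
```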

4 Testing

Download a jpg image from the web, rename it sunflower.jpg, and put it in the AlexNet folder.

Open a terminal in the AlexNet folder:

gedit predict.py # paste in predict.py, then save and close
python predict.py # run predict.py

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# load the image
img = Image.open("./sunflower.jpg")   # test a sunflower
#img = Image.open("./roses.jpg")      # test a rose

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indices.json
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# build the model
model = AlexNet(num_classes=5)
# load the trained weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()  # disable dropout
with torch.no_grad():
    output = torch.squeeze(model(img))           # drop the batch dimension
    predict = torch.softmax(output, dim=0)       # convert logits to probabilities
    predict_cla = torch.argmax(predict).numpy()  # index of the most likely class
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
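The final step, softmax followed by argmax, is independent of the model itself. Here is the same computation in plain Python on made-up logits:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.2, 0.3, 4.0, 0.1, -1.5]   # made-up network outputs for the 5 classes
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
print(pred)  # 2, the index of the largest logit
```

The probabilities sum to 1, and argmax of the softmax equals argmax of the raw logits, so the softmax only matters for reporting a confidence value.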

5 Scraping images and testing on them

5.1 Scraping test images

Scraping with English keywords returns some inappropriate images, so the Chinese keywords are used instead. The images are saved under 'flower_data/predict'.

get_data.py

import requests
import urllib.parse as up
import json
import time
import os

major_url = 'https://image.baidu.com/search/index?'
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}

def pic_spider(kw, path, page=10):
    path = os.path.join(path, class_indict[str(j)])  # search in Chinese, save under the English class name
    if not os.path.exists(path):
        os.mkdir(path)
    if kw != '':
        for num in range(page):
            data = {
                "tn": "resultjson_com", "logid": "11587207680030063767", "ipn": "rj",
                "ct": "201326592", "is": "", "fp": "result", "queryWord": kw,
                "cl": "2", "lm": "-1", "ie": "utf-8", "oe": "utf-8", "adpicid": "",
                "st": "-1", "z": "", "ic": "0", "hd": "", "latest": "", "copyright": "",
                "word": kw, "s": "", "se": "", "tab": "", "width": "", "height": "",
                "face": "0", "istype": "2", "qc": "", "nc": "1", "fr": "",
                "expermode": "", "force": "", "pn": num*30, "rn": "30",
                "gsm": oct(num*30), "1602481599433": ""}
            url = major_url + up.urlencode(data)
            i = 0
            pic_list = []
            while i < 5:
                try:
                    pic_list = requests.get(url=url, headers=headers).json().get('data')
                    break
                except:
                    print('bad connection, retrying...')
                    i += 1
                    time.sleep(1.3)
            for pic in pic_list:
                url = pic.get('thumbURL', '')  # some results have no image link; skip them
                if url == '':
                    continue
                name = pic.get('fromPageTitleEnc')
                for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
                    name = name.replace(char, '')  # strip characters that are illegal in file names
                type = pic.get('type', 'jpg')  # image type; default to jpg if missing
                pic_path = (os.path.join(path, '%s.%s') % (name, type))
                print(name, 'downloaded')
                if not os.path.exists(pic_path):
                    with open(pic_path, 'wb') as f:
                        f.write(requests.get(url=url, headers=headers).content)

cwd = os.getcwd()  # current working directory
file1 = 'flower_data/flower_photos'
file2 = 'flower_data/predict'
predict_data = os.path.join(cwd, file2)
flower_class = [cla for cla in os.listdir(file1) if ".txt" not in cla]

try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

#print(class_indict['daisy'])
print(type(class_indict))
print(class_indict)

# the keywords are deliberately kept in Chinese (see the note above)
for j, cla in enumerate(['雏菊', '蒲公英', '玫瑰花', '太阳花', '郁金香'], 0):
    pic_spider(cla, predict_data, page=10)
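The character-stripping loop inside the spider can be factored out and checked on its own. A small sketch (the function name is mine, not from the post):

```python
def sanitize(name):
    """Remove characters that are illegal in file names, as the spider does."""
    for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
        name = name.replace(char, '')
    return name

print(sanitize('tulip?photo:large/1.jpg'))  # tulipphotolarge1.jpg
```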

5.2 Testing on the scraped flowers

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import os

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cwd = os.getcwd()  # current working directory
predict = 'flower_data/predict'
predict_path = os.path.join(cwd, predict)
#flowers = ['雏菊','蒲公英','玫瑰花','太阳花','郁金香']
#flowers = [flower for flower in os.listdir(predict_path)]

try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

for j, flower in class_indict.items():
    print(">> testing: ", flower)
    path = os.path.join(predict_path, flower)
    images = [f1 for f1 in os.listdir(path) if ".gif" not in f1]  # filter out gif animations
    acc_ = [0, 0, 0, 0, 0]
    for image in images:
        # load the image; .convert('RGB') avoids
        # "RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0"
        img = Image.open(path + '/' + image).convert('RGB')
        plt.imshow(img)
        # [N, C, H, W]
        img = data_transform(img)
        # expand batch dimension
        img = torch.unsqueeze(img, dim=0)
        # create the model
        model = AlexNet(num_classes=5)
        # load the trained weights
        model_weight_path = "./AlexNet.pth"
        model.load_state_dict(torch.load(model_weight_path))
        model.eval()
        with torch.no_grad():
            # predict the class
            output = torch.squeeze(model(img))
            predict = torch.softmax(output, dim=0)
            predict_flower = torch.argmax(predict).numpy()
            print(class_indict[str(predict_flower)], '\t', predict[predict_flower].item())
            #print(str(predict_flower))
            acc_[predict_flower] += 1
    print("acc_: ", acc_)
    print("{} has {} images in total\nof which  {}: {}  {}: {}  {}: {}  {}: {}  {}: {}".format(
        flower, len(images),
        class_indict[str(0)], acc_[0], class_indict[str(1)], acc_[1],
        class_indict[str(2)], acc_[2], class_indict[str(3)], acc_[3],
        class_indict[str(4)], acc_[4]))
    print("{} accuracy: {}%".format(flower, 100 * acc_[int(j)] / len(images)))
    print("\n\n")

print(">> testing complete!")
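The `acc_` bookkeeping above is just a histogram of predicted class indices; the diagonal entry (true class) divided by the image count gives the per-class accuracy. A standalone sketch with hypothetical predictions:

```python
def tally(predictions, num_classes=5):
    """Count how many images landed in each predicted class (the acc_ list above)."""
    counts = [0] * num_classes
    for p in predictions:
        counts[p] += 1
    return counts

preds = [3, 3, 2, 3, 0, 3, 3, 1, 3, 3]  # hypothetical predictions for 10 sunflower images
counts = tally(preds)
print(counts)                  # [1, 1, 1, 7, 0]
print(counts[3] / len(preds))  # 0.7 -> 70% accuracy when the true class index is 3
```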

Sunflowers come in below 30% accuracy, while the other flowers reach about 60%.

6 Chinese zodiac (scrape, split, train/validate, test)

Create a new data folder (数据) inside the AlexNet folder.

Scraping
get_data.py

import requests
import urllib.parse as up
import json
import time
import os

major_url = 'https://image.baidu.com/search/index?'
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}

def pic_spider(kw, path, page=10):
    path = os.path.join(path, kw)  # this time, save under the keyword itself
    if not os.path.exists(path):
        os.mkdir(path)
    if kw != '':
        for num in range(page):
            data = {
                "tn": "resultjson_com", "logid": "11587207680030063767", "ipn": "rj",
                "ct": "201326592", "is": "", "fp": "result", "queryWord": kw,
                "cl": "2", "lm": "-1", "ie": "utf-8", "oe": "utf-8", "adpicid": "",
                "st": "-1", "z": "", "ic": "0", "hd": "", "latest": "", "copyright": "",
                "word": kw, "s": "", "se": "", "tab": "", "width": "", "height": "",
                "face": "0", "istype": "2", "qc": "", "nc": "1", "fr": "",
                "expermode": "", "force": "", "pn": num*30, "rn": "30",
                "gsm": oct(num*30), "1602481599433": ""}
            url = major_url + up.urlencode(data)
            i = 0
            pic_list = []
            while i < 5:
                try:
                    pic_list = requests.get(url=url, headers=headers).json().get('data')
                    break
                except:
                    print('bad connection, retrying...')
                    i += 1
                    time.sleep(1.3)
            for pic in pic_list:
                url = pic.get('thumbURL', '')  # some results have no image link; skip them
                if url == '':
                    continue
                name = pic.get('fromPageTitleEnc')
                for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
                    name = name.replace(char, '')  # strip characters that are illegal in file names
                type = pic.get('type', 'jpg')  # image type; default to jpg if missing
                pic_path = (os.path.join(path, '%s.%s') % (name, type))
                print(name, 'downloaded')
                if not os.path.exists(pic_path):
                    with open(pic_path, 'wb') as f:
                        f.write(requests.get(url=url, headers=headers).content)

cwd = os.getcwd()  # current working directory
file1 = 'flower_data/flower_photos'
file2 = '数据/十二生肖'  # '数据' = data, '十二生肖' = Chinese zodiac
save_path = os.path.join(cwd, file2)
#flower_class = [cla for cla in os.listdir(file1) if ".txt" not in cla]
lists = ['猫', '牛', '虎', '兔', '龙', '蛇', '马', '羊', '猴', '鸡', '狗', '猪']
for list in lists:
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    pic_spider('动物' + list, save_path, page=10)

Splitting the data
spile_data.py

import os
from shutil import copy
import random

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

#file = 'flower_data/flower_photos'
file = '数据/十二生肖'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
#mkfile('flower_data/train')
mkfile('数据/train')
for cla in flower_class:
    #mkfile('flower_data/train/'+cla)
    mkfile('数据/train/'+cla)
#mkfile('flower_data/val')
mkfile('数据/val')
for cla in flower_class:
    #mkfile('flower_data/val/'+cla)
    mkfile('数据/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        if image in eval_index:
            image_path = cla_path + image
            new_path = '数据/val/' + cla  # fixed: the original still copied into 'flower_data/val'
            copy(image_path, new_path)
        else:
            image_path = cla_path + image
            new_path = '数据/train/' + cla  # fixed: the original still copied into 'flower_data/train'
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # progress bar
    print()
print("processing done!")

