想法来源:李宏毅老师的机器学习课
通过改写Alexnet网络最后一层全连接层,使其能分辨宝可梦和数码宝贝。

数据集准备以及预处理

数据来源:
宝可梦
数码宝贝

  • 下载宝可梦和数码宝贝,设置训练集和测试集
  • 将图片搜集好后,通过transforms将图像转换为张量(tensor)格式,并对其进行归一化,其中,我使用了PyTorch的torchvision模块中的ImageFolder类来加载我们的数据集。
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Datasettransform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
# 因为我使用GPU训练模型,因此通过参数pin_memory=True来告诉DataLoader将数据加载到固定的内存中,这样可以更快地将数据传输到GPU
trainset = torchvision.datasets.ImageFolder(root='Dataset/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)testset = torchvision.datasets.ImageFolder(root='Dataset/test', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

构建模型

从torchvision中获取预训练的AlexNet模型,将最后一层输出1000个类别的全连接层改为输出2个类别

import torchvision.models as models
from torch import nn, optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = models.alexnet(weights=True)
num_features = model.classifier[-1].in_features
model.classifier[-1] = torch.nn.Linear(num_features, 2)

训练模型

超参epoch设为20,学习率为0.001
使用随机梯度下降(SGD)优化器和交叉熵损失函数对其进行训练。

model = model.to(device)criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)n_epochs = 20for epoch in range(n_epochs):running_loss = 0.0model.train()for i, data in enumerate(trainloader, 0):inputs, labels = datainputs = inputs.type(torch.FloatTensor).to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i == 54:print('[%d, %5d] loss: %.5f' % (epoch , i, running_loss/55))# torch.save(model.state_dict(),"saveModel.pt")

测试模型

correct = 0
total = 0
model.eval()
with torch.no_grad():for data in testloader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)print("predicte_mon",predicted)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy : %d %%' % (100 * correct / total))

输出准确率:(因为我的测试集只设了30个,所以准确度很高)

使用AlexNet网络区分宝可梦和数码宝贝相关推荐

  1. 宝可梦维护服务器,宝可梦大师卡在登录界面进不去,宝可梦大师为啥玩不了

    游戏介绍:<宝可梦大师>是一款由The Pokemon Company发行的手机游戏. 游戏中玩家将扮演训练家,带领宝可梦和队员们以"世界宝可梦大师赛"的冠军为目标开始 ...

  2. 用python画皮卡丘源代码-实现童年宝可梦,教你用Python画一只属于自己的皮卡丘...

    原标题:实现童年宝可梦,教你用Python画一只属于自己的皮卡丘 大数据文摘出品 作者:李雷.蒋宝尚 还记得小时候疯狂收集和交换神奇宝贝卡片的经历吗? 还记得和小伙伴拿着精灵球,一起召唤小精灵的中二模 ...

  3. python简单代码画皮卡丘-实现童年宝可梦,教你用Python画一只属于自己的皮卡丘...

    原标题:实现童年宝可梦,教你用Python画一只属于自己的皮卡丘 大数据文摘出品 作者:李雷.蒋宝尚 还记得小时候疯狂收集和交换神奇宝贝卡片的经历吗? 还记得和小伙伴拿着精灵球,一起召唤小精灵的中二模 ...

  4. 用python画皮卡丘代码-实现童年宝可梦,教你用Python画一只属于自己的皮卡丘

    大数据文摘出品 作者:李雷.蒋宝尚 还记得小时候疯狂收集和交换神奇宝贝卡片的经历吗? 还记得和小伙伴拿着精灵球,一起召唤小精灵的中二模样吗? 最近上映的<大侦探皮卡丘>,是否会让你秒回童年 ...

  5. 深度学习tensorflow实现宝可梦图像分类

    目录 一.数据集简介 二.数据预处理 三.构建卷积神经网络 四.模型训练 五.预测 六.分析与优化 一.数据集简介 宝可梦数据集(共1168张图像):bulbasaur(妙蛙种子,234).charm ...

  6. 用python画皮卡丘-实现童年宝可梦,教你用Python画一只属于自己的皮卡丘

    大数据文摘出品 作者:李雷.蒋宝尚 还记得小时候疯狂收集和交换神奇宝贝卡片的经历吗? 还记得和小伙伴拿着精灵球,一起召唤小精灵的中二模样吗? 最近上映的<大侦探皮卡丘>,是否会让你秒回童年 ...

  7. 基于ResNetRS的宝可梦图像识别

    基于ResNetRS的宝可梦图像识别 1.ResNet-D架构 2.ResNetRS架构 3.手动搭建模型(Tensorflow) 3.1 模型配置项 3.2 get_survival_probabi ...

  8. 怎么用python画皮卡丘_实现童年宝可梦,教你用Python画一只属于自己的皮卡丘

    原标题:实现童年宝可梦,教你用Python画一只属于自己的皮卡丘 大数据文摘出品 作者:李雷.蒋宝尚 还记得小时候疯狂收集和交换神奇宝贝卡片的经历吗? 还记得和小伙伴拿着精灵球,一起召唤小精灵的中二模 ...

  9. 个体值0和31差多少攻击_口袋妖怪:攻略篇!个体值有多重要?6V宝可梦才是完美的精灵!...

    爱生活,爱游戏,大家好,我是你们的好朋友汤圆.关注汤圆,收获更多快乐哦! 口袋妖怪到目前为止有近900只精灵,18种属性,每种属性都代表着克制关系,但我们今天不讲这个,我们讲精灵的能力值,看看你心仪的 ...

最新文章

  1. java的线程管理器,QuickThread - Java线程池管理器
  2. mysql用sql语句怎么做个脚本备份_mysql备份脚本
  3. HTML5超链接和多媒体,IT兄弟连 HTML5教程 多媒体应用 创建图像和链接
  4. Android 自定义 View
  5. oracle 分组统计行数,求助分组之后进行统计行数
  6. 你看得懂的CSMA介质访问控制原理
  7. ios9版本的iphone,不执行网页js
  8. 基本初等函数导数公式表
  9. 恒生电子面试(面试介绍,面试流程,面试建议,面试题库(软测方向))
  10. chm格式电子书另类反编译法:使用压缩软件7Z简单实现CHM电子书反编译 | 志文工作室
  11. photoshop修色圣典 第5版pdf
  12. Jvav-C++/真正的Jvav
  13. 【JavaBigDecimal练习】利用BigDecimal精确计算欧拉数
  14. 有运气摇号来不及挑选?网易有数帮你科学选房
  15. 可能是因为该宏在此工作簿中不可用,或者所有的宏都被禁用
  16. 从零开始搭建SpringBoot项目(一)——开发环境搭建(图文详细)
  17. Linux执行某些命令缺少libtinfo.so.5
  18. 你可知道,让你发胖的食物不是高脂肪食物,而是高碳水化合物
  19. 一份完整的单机版slurm部署
  20. 解决Manifest merger failed : Attribute application@appComponentFactory

热门文章

  1. Windows Phone 8.1 新特性 - 控件之列表选择控件
  2. SQL学习笔记(1)
  3. batch,draw call
  4. Hystrix(豪猪)的原理探索(一)
  5. 2011年最新出炉的爆笑签名
  6. 佳能相机的拍照应用开发canon EDSDK C#
  7. 2017幼儿园计算机培训,2017幼儿园
  8. javaEE面试重点
  9. Warning: can't write resource [META-INF/MANIFEST.MF] (Duplicate zip entry [yyy.jar:META-INF/MANIFEST
  10. iOS 应用提交 App Store 上架被拒的原因收集