一:数据集准备:     

数据集来源于研习社:猫狗大战--经典图像分类题 - AI算法竞赛-AI研习社 (yanxishe.com)

二:读取数据:

class Cat_and_Dog_Dataset(Dataset):def __init__(self, filepath):self.images = []self.labels = []self.transform = transformfor filename in tqdm(os.listdir(filepath)):image = Image.open(filepath + filename)image = image.resize((224,224))         #裁剪图片image = self.transform(image)           #转化为Tensor格式self.images.append(image)if filename.split('_')[0] == 'cat':     #添加标签self.labels.append(0)elif filename.split('_')[0] == 'dog':self.labels.append(1)self.labels = torch.LongTensor(self.labels) #标签转化位Tensor格式print(self.labels)def __getitem__(self, index):           #构造迭代器return self.images[index], self.labels[index]def __len__(self):                      #迭代器长度images = np.array(self.images)len = images.shape[0]return lentrain_data = Cat_and_Dog_Dataset('cat_dog/train/')              #加载训练集
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)  val_data = Cat_and_Dog_Dataset('cat_dog/val/')                  #加载验证集
val_loader = DataLoader(dataset=val_data, batch_size=64, shuffle=True)

二:搭建网络:

class InceptionA(torch.nn.Module):                              #构造Inception层def __init__(self,in_channels):super(InceptionA,self).__init__()self.branch1x1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch5x5_1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size = 5,padding = 2)self.branch3x3_1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size = 3,padding = 1)self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size = 3,padding = 1)self.branch_pool = torch.nn.Conv2d(in_channels,24,kernel_size = 1)def forward(self,x):branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_1(x)branch5x5 = self.branch5x5_2(branch5x5)branch3x3 = self.branch3x3_1(x)branch3x3 = self.branch3x3_2(branch3x3)branch3x3 = self.branch3x3_3(branch3x3)branch_pool = F.avg_pool2d(x,kernel_size = 3,stride = 1,padding = 1)branch_pool = self.branch_pool(branch_pool)outputs = [branch1x1,branch3x3,branch5x5,branch_pool]return torch.cat(outputs,dim = 1)class Net(torch.nn.Module):                             #构造网络def __init__(self):super(Net,self).__init__()self.conv1 = torch.nn.Conv2d(3,10,kernel_size = 5)self.conv2 = torch.nn.Conv2d(88,20,kernel_size = 5)self.incep1 = InceptionA(in_channels = 10)self.incep2 = InceptionA(in_channels = 20)self.mp = torch.nn.MaxPool2d(2)self.fc = torch.nn.Linear(247192,2)def forward(self,x):in_size = x.size(0)x = self.mp(F.relu((self.conv1(x))))x = self.incep1(x)x = self.mp(F.relu((self.conv2(x))))x = self.incep2(x)x = x.view(in_size,-1)x = self.fc(x)return x

三:构建损失函数和优化器:

criterion = torch.nn.CrossEntropyLoss()         #构造损失函数
optimizer = optim.SGD(model.parameters(),lr = 0.001)        #构造优化器

四:训练模型:

def train(epoch):running_loss = 0.0                              #训练模型for batch_idx,data in enumerate(train_loader,0):inputs,targets = datainputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()                       #梯度清零outputs = model(inputs)                     #前向传播loss = criterion(outputs,targets)           #计算损失loss.backward()                             #逆向传播optimizer.step()                            #梯度递进running_loss += loss.item()print('train loss: %.3f' % (running_loss/batch_idx))

五:验证模型:

def val():                                          #验证模型精度correct = 0total = 0with torch.no_grad():                           #不需要梯度,减少计算量for data in val_loader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('accuracy on test set: %d %% ' % (100*correct/total))return correct/total

源码:

import torch
import numpy as np
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as pltbatch_size = 64transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])#设置transformclass Cat_and_Dog_Dataset(Dataset):def __init__(self, filepath):self.images = []self.labels = []self.transform = transformfor filename in tqdm(os.listdir(filepath)):image = Image.open(filepath + filename)image = image.resize((224,224))         #裁剪图片image = self.transform(image)           #转化为Tensor格式self.images.append(image)if filename.split('_')[0] == 'cat':     #添加标签self.labels.append(0)elif filename.split('_')[0] == 'dog':self.labels.append(1)self.labels = torch.LongTensor(self.labels) #标签转化位Tensor格式print(self.labels)def __getitem__(self, index):           #构造迭代器return self.images[index], self.labels[index]def __len__(self):                      #迭代器长度images = np.array(self.images)len = images.shape[0]return lentrain_data = Cat_and_Dog_Dataset('cat_dog/train/')              #加载训练集
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)  val_data = Cat_and_Dog_Dataset('cat_dog/val/')                  #加载验证集
val_loader = DataLoader(dataset=val_data, batch_size=64, shuffle=True)class InceptionA(torch.nn.Module):                              #构造Inception层def __init__(self,in_channels):super(InceptionA,self).__init__()self.branch1x1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch5x5_1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch5x5_2 = torch.nn.Conv2d(16,24,kernel_size = 5,padding = 2)self.branch3x3_1 = torch.nn.Conv2d(in_channels,16,kernel_size = 1)self.branch3x3_2 = torch.nn.Conv2d(16,24,kernel_size = 3,padding = 1)self.branch3x3_3 = torch.nn.Conv2d(24,24,kernel_size = 3,padding = 1)self.branch_pool = torch.nn.Conv2d(in_channels,24,kernel_size = 1)def forward(self,x):branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_1(x)branch5x5 = self.branch5x5_2(branch5x5)branch3x3 = self.branch3x3_1(x)branch3x3 = self.branch3x3_2(branch3x3)branch3x3 = self.branch3x3_3(branch3x3)branch_pool = F.avg_pool2d(x,kernel_size = 3,stride = 1,padding = 1)branch_pool = self.branch_pool(branch_pool)outputs = [branch1x1,branch3x3,branch5x5,branch_pool]return torch.cat(outputs,dim = 1)class Net(torch.nn.Module):                             #构造网络def __init__(self):super(Net,self).__init__()self.conv1 = torch.nn.Conv2d(3,10,kernel_size = 5)self.conv2 = torch.nn.Conv2d(88,20,kernel_size = 5)self.incep1 = InceptionA(in_channels = 10)self.incep2 = InceptionA(in_channels = 20)self.mp = torch.nn.MaxPool2d(2)self.fc = torch.nn.Linear(247192,2)def forward(self,x):in_size = x.size(0)x = self.mp(F.relu((self.conv1(x))))x = self.incep1(x)x = self.mp(F.relu((self.conv2(x))))x = self.incep2(x)x = x.view(in_size,-1)x = self.fc(x)return xmodel = Net()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)            #转换为cuda格式criterion = torch.nn.CrossEntropyLoss()         #构造损失函数
optimizer = optim.SGD(model.parameters(),lr = 0.001)        #构造优化器def train(epoch):running_loss = 0.0                              #训练模型for batch_idx,data in enumerate(train_loader,0):inputs,targets = datainputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()                       #梯度清零outputs = model(inputs)                     #前向传播loss = criterion(outputs,targets)           #计算损失loss.backward()                             #逆向传播optimizer.step()                            #梯度递进running_loss += loss.item()print('train loss: %.3f' % (running_loss/batch_idx))def val():                                          #验证模型精度correct = 0total = 0with torch.no_grad():                           #不需要梯度,减少计算量for data in val_loader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('accuracy on test set: %d %% ' % (100*correct/total))return correct/totalif __name__ == '__main__':accuracy_list = []epoch_list = []for epoch in range(10):train(epoch)acc = val()accuracy_list.append(acc)epoch_list.append(epoch)plt.plot(epoch_list,accuracy_list)plt.xlabel(epoch)plt.ylabel(accuracy_list)plt.show()

Pytorch实战:用经典网络实现猫狗大战相关推荐

  1. PyTorch实现:经典网络 NiN

    PyTorch实现:经典网络 NiN (Network in Network) 多输入多输出卷积: 一般输入的图像具有 3∗h∗w3 * h * w3∗h∗w 的形状,是一个通道数为3,尺寸h∗wh ...

  2. PyTorch实现:经典网络 ResNet

    经典网络 ResNet 1 简述 GoogleNet 和 VGG 等网络证明了,更深度的网络可以抽象出表达能力更强的特征,进而获得更强的分类能力.在深度网络中,随之网络深度的增加,每层输出的特征图分辨 ...

  3. 手撕 CNN 经典网络之 VGGNet(PyTorch实战篇)

    大家好,我是红色石头! 在上一篇文章: 手撕 CNN 经典网络之 VGGNet(理论篇) 详细介绍了 VGGNet 的网络结构,今天我们将使用 PyTorch 来复现VGGNet网络,并用VGGNet ...

  4. Pytorch实战第一步--用经典神经网络实现猫狗大战

    文章目录 前言 一.猫狗大战数据集 二.pytorch实战 1.程序整体结构 2.读入数据 3.网络结构 4.网络结构 5测试 总结 总结 前言 随着人工智能的不断发展,机器学习这门技术也越来越重要, ...

  5. 手撕 CNN 之 AlexNet(PyTorch 实战篇)

    大家好,我是红色石头! 在上一篇文章: 手撕 CNN 经典网络之 AlexNet(理论篇) 详细介绍了 AlexNet 的网络结构,今天我们将使用 PyTorch 来复现AlexNet网络,并用Ale ...

  6. 神经网络中的注意力机制总结及PyTorch实战

    技术交流 QQ 群:1027579432,欢迎你的加入! 欢迎关注我的微信公众号:CurryCoder的程序人生 0.概述 当神经网络来处理大量的输入信息时,也可以借助人脑的注意力机制,只选择一些关键 ...

  7. 手撕 CNN 经典网络之 VGGNet(理论篇)

    2014年,牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司一起研发了新的卷积神经网络,并命名为VGGNet.VGGNet是比AlexNet更深的 ...

  8. 【深度学习】手撕 CNN 之 AlexNet(PyTorch 实战篇)

    今天我们将使用 PyTorch 来复现AlexNet网络,并用AlexNet模型来解决一个经典的Kaggle图像识别比赛问题. 正文开始! 1. 数据集制作 在论文中AlexNet作者使用的是ILSV ...

  9. PyTorch实战使用Resnet迁移学习

    PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower_data文件夹 cat_to_name.json是makejson文件运行生 ...

最新文章

  1. oracle 分区使用情况,Oracle Hash分区的使用总结
  2. OpenGL硬件加速指南
  3. linux下进程的tty,Linux下TTY驱动程序分析
  4. Java架构师成长之道之浅谈计算机系统架构
  5. struts2官方 中文教程 系列一:创建一个struts2 web Application
  6. 介绍 WebLogic 的一些结构和特点
  7. Django1.7开发博客
  8. OpenCV2.4.4中调用SIFT特征检测器进行图像匹配
  9. VLAN TRUNK 链路聚合 网络层路由器
  10. Free SQLSever 2008的书
  11. python解放二次开发_[转载]Python二次开发程序详解
  12. 使用.NET Core进行Linux编程3:简介和第2章
  13. mysql主从搭建_手把手教你搭建MySQL主从架构
  14. oracle的sql的语法解析,oracle SQL解析步骤小结
  15. 计算机专业可以当警校吗,警校开设的计算机类专业,毕业生就业方向偏向于信息安全,请注意...
  16. 【转】为什么360这种软件存活至今?程序员:打不死的小强,春风吹又生
  17. 根据经纬度查询地名,
  18. 文档在线查看功能的实现
  19. 只用div+CSS做淘宝手机端首页
  20. 高通骁龙820A凭什么能赢得众多车厂的芳心

热门文章

  1. 骆驼(Camel)命名法、帕斯卡(Pascal)命名法、匈牙利命名法
  2. 寻找真正的入口(OEP)--广义ESP定律
  3. IGBT器件选型参考
  4. 如何用Qt绘制一颗好看的二叉树
  5. iastora怎么改成ahci_怎么把硬盘更改成ahci模式 AHCI功能开启方法
  6. 葡,西两国发展史(大航海时代)启示
  7. 《Swift4打造今日头条视频实战项目实战》最新
  8. 会声会影X10中文版序列号32位/64位下载教程
  9. Win10:鼠标右键如何添加快捷关机、注销等功能
  10. 2019 年 (A 题) 电动小车动态无线充电系统