目录

一.数据处理

二.构造网络

三.训练和测试

四.展示结果


一.数据处理

Dogs vs. Cats(猫狗大战),其中训练集有20000张,猫狗各占一半,验证集20000,测试集2000张,没有标定是猫还是狗。要求设计一种算法对测试集中的猫狗图片进行判别,是一个传统的二分类问题。

拿到数据,先查看数据集,可以看到图片的大小均不一致且没有y值。所以我们需要自己在先将数据处理好。

以下是代码。

class Datachushihua(Dataset):def __init__(self,mode,dir) :self.mode=modeself.img_list=[]# 存放路径self.label_list=[] #转化输出self.data_size=0self.transform=afterdataif self.mode == 'train'or self.mode =='val':dir = dir + self.mode+'/'for file in tqdm(os.listdir(dir)):img = Image.open(dir+file)   # 打开图片self.img_list.append(self.transform(img))self.data_size += 1file_try=filename=file_try.split(sep='_') #切割字符串返回if name[0]=='cat':self.label_list.append(0)else:self.label_list.append(1)self.label_list = torch.LongTensor(self.label_list) elif self.mode == 'test':           # 测试集模式下,只需要提取图片路径就行dir = dir + '/test/'            # 测试集路径为"dir"/test/for file in os.listdir(dir):self.img_list.append(dir + file)    # 添加图片路径至image listself.data_size += 1else:print('Undefined Dataset!')def __getitem__(self,item):#返回下标if self.mode == 'train'or self.mode =='val':                                      # 训练集模式下需要读取数据集的image和labelreturn self.img_list[item], self.label_list[item]elif self.mode == 'test':                                       # 测试集只需读取imageimg = Image.open(self.img_list[item])return self.transform(img)                                  # 只返回imageelse:print('None')def __len__(self):return self.data_size  

二.构造网络

构造model并且加入inception和残差网络,具体的网络结构如图所示。

#网络结构
class Model(torch.nn.Module):def __init__(self) :super(Model,self).__init__()self.Conv1=torch.nn.Conv2d(3, 16, 3, padding=1)   #3*200*200-->16*200*200self.Conv2=torch.nn.Conv2d(88, 16, 3, padding=1)self.incept1 = InceptionA(in_channels=16)self.rblock1=ResidualBlock(88)self.pooling=torch.nn.MaxPool2d(2)self.linear1=torch.nn.Linear(220000,2)# self.linear4=torch.nn.Linear(40000,128)## #self.linear3=torch.nn.Linear(64,2)def forward(self,x):x=F.relu(self.Conv1(x))x=self.pooling(x)#16*200*200-->16*100*100x = self.incept1(x)#16*100*100-->88*100*100x=self.rblock1(x)x=F.relu(self.Conv2(x))#88*100*100-->16*100*100x=self.pooling(x)#16*100*100-->16*50*50x = self.incept1(x)#16*50*50-->88*50*50x=self.rblock1(x)x=x.view(x.size()[0], -1)#x=F.relu(self.linear1(x))#x=F.relu(self.linear4(x))#x=F.relu(self.linear2(x))x=self.linear1(x)return x# 构造自己的Netclass ResidualBlock(torch.nn.Module):def __init__(self,channels):super(ResidualBlock,self).__init__()self.channels=channelsself.conv1=torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)#为了使大小不变self.conv2=torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)def forward(self,x):y=F.relu(self.conv1(x))y=self.conv2(y)return F.relu(x+y)   # 构造Inception block
class InceptionA(torch.nn.Module):def __init__(self, in_channels):super(InceptionA, self).__init__()self.averag_pool = torch.nn.AvgPool2d(kernel_size=3, stride=1, padding=1)self.conv1_1_24 = torch.nn.Conv2d(in_channels,24,kernel_size=1)self.conv1_1_16 = torch.nn.Conv2d(in_channels,16,kernel_size=1)self.conv5_5_24 = torch.nn.Conv2d(16,24,kernel_size=5,padding=2)self.conv3_3_24_1 = torch.nn.Conv2d(16,24,kernel_size=3, padding=1)self.conv3_3_24_2 = torch.nn.Conv2d(24,24,kernel_size=3,padding=1)def forward(self, x):x1 = self.averag_pool(x)#不改变大小x1 = self.conv1_1_24(x1)#16*100*100-->24*100*100(1)x2 = self.conv1_1_16(x)#16*100*100-->16*100*100(1)x3 = self.conv1_1_16(x)#16*100*100-->16*100*100(1)x3 = self.conv5_5_24(x3)#16*100*100-->24*100*100(1)x4 = self.conv1_1_16(x)#16*100*100-->16*100*100(1)x4 = self.conv3_3_24_1(x4)#16*100*100-->24*100*100(1)x4 = self.conv3_3_24_2(x4)#24*100*100-->24*100*100outputs = [x1,x2,x3,x4]return torch.cat(outputs, dim=1)#88*100*100

三.训练和测试

这里的test修改了以下,可以随机从测试集中选取图片。

model=Model()
model.cuda()
batch_size=16
lr=0.001
sunshi=torch.nn.CrossEntropyLoss()
youhua=optim.SGD(model.parameters(),lr=lr,momentum=0.5)#训练,验证,与测试
def train(epoch):print("epoch:",epoch+1)running_loss = 0for batch_index,data in enumerate(tqdm(train_loader)):input,target=datainput = input.cuda()target = target.cuda()output=model(input)loss=sunshi(output,target)running_loss += lossyouhua.zero_grad()loss.backward()youhua.step()print("train loss:", (running_loss).item()/batch_index)     torch.save(model.state_dict(), './model.pth')              # 训练所有数据后,保存网络的参数def val(val_loader, model):total = 0correct = 0with torch.no_grad():for batch_index, data in enumerate(val_loader,0):inputs, labels = datainputs = inputs.cuda()labels = labels.cuda()outputs = model(inputs)# 取维度最大_, predicts = torch.max(outputs,dim=1)total += labels.size(0)correct += (predicts==labels).sum().item()print("正确率:", correct/total)return correct/total
def test():font={  'color': 'red','size': 20,'family': 'Times New Roman','style':'italic'}model=Model()model = model.cuda()model.load_state_dict(torch.load("model.pth"))  index = np.random.randint(0, test_dataset.data_size, 1)[0]      # 获取一个随机数,即随机从数据集中获取一个测试图片img = test_dataset.__getitem__(index)                           # 获取一个图像img = img.unsqueeze(0)                                      # 因为网络的输入是一个4维Tensor,3维数据,1维样本大小,所以直接获取的图像数据需要增加1个维度img = img.cuda()                                  # 将数据放置在PyTorch的Variable节点中,并送入GPU中作为网络计算起点out = model(img)                                            # 网路前向计算,输出图片属于猫或狗的概率,第一列维猫的概率,第二列为狗的概率out = F.softmax(out, dim=1)                                        # 采用SoftMax方法将输出的2个输出值调整至[0.0, 1.0],两者和为1img = Image.open(test_dataset.img_list[index])      # 打开测试的图片plt.figure('image')plt.imshow(img)if out[0, 0] > out[0, 1]:                 # 猫的概率大于狗plt.text(0, -6.0, "prediction: cat", fontdict=font)else:                                       # 猫的概率小于狗plt.text(0, -6.0, "prediction: dog", fontdict=font)plt.show()

四.展示结果

dataset_dir='./cat_dog/'
train_dataset = Datachushihua('train', dataset_dir)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print('Dataset loaded! length of train set is {0}'.format(len(train_loader)))
val_dataset=Datachushihua('val', dataset_dir)
val_loader = DataLoader(val_dataset, batch_size=64)
test_dataset=Datachushihua('test', dataset_dir)
test_loader = DataLoader(test_dataset, batch_size=64)if __name__=='__main__':for epoch in range(20):train(epoch)val(val_loader, model)test()

可以看到正确率稳定在了 78%左右。

主要·时间还是浪费在了数据处理那一块,好在网络改进后正确率提高了。

图像分类——猫狗大战问题相关推荐

  1. PyTorch深度学习图像分类--猫狗大战

    PyTorch深度学习图像分类--猫狗大战 1.背景介绍 2.环境配置 2.1软硬件清单 2.1.1配置PyPorch 2.1.2开发软件 2.1.3 显卡 2.2 数据准备 3 基础理论 3.1Py ...

  2. 基于PyTorch的卷积神经网络图像分类——猫狗大战(二):使用Pytorch定义网络模型

    文章目录 1. 需要用到的库 2. 模型定义 3. 测试 基于上一篇文章 https://blog.csdn.net/linghu8812/article/details/100044971,这次介绍 ...

  3. 猫狗大战(kaggle竞赛-猫狗图像分类)

    本实验使用kaggle中猫狗大战中的部分数据集(2000张训练数据+500张测试数据) 本次实验中使用了DNN.CNN.RNN分别进行了图像识别,具体代码如下: DNN模型: 全连接层 神经元个数 F ...

  4. 赠书 | 图像分类问题建模方案探索实践

    作者 | 中国农业银行 陆春晖 责编 | 晋兆雨 出品 | AI科技大本营 头图 | 付费下载于视觉中国 *文末有赠书福利 背景 图像分类,是计算机视觉领域的一个核心问题,顾名思义就是输入一张图像,根 ...

  5. 【经验分享】TinyMind 多标签图像分类竞赛小试牛刀——by:for the dream

    多标签图像分类竞赛地址:https://www.tinymind.cn/competitions/42?from=blog 队伍:for the dream,其实是大酒神死忠粉~ 初次拿到这个题目,想 ...

  6. python +keras实现图像分类(入门级例子讲解)

    一.项目描述 数据集来源于kaggle猫狗大战数据集.训练集有25000张,猫狗各占一半.测试集12500张.希望计算机可以从这些训练集图片中学习到猫狗的特征,从而使得计算机可以正确的对未曾见过的猫狗 ...

  7. 【图像分类】实战——使用ResNet实现猫狗分类(pytorch)

    目录 摘要 导入项目使用的库 设置全局参数 图像预处理 读取数据 设置模型 设置训练和验证 验证 完整代码: 摘要 ResNet(Residual Neural Network)由微软研究院的Kaim ...

  8. 基于Keras2《面向小数据集构建图像分类模型》——Kaggle猫狗数据集

    概述 在本文中,将使用VGG-16模型提供一种面向小数据集(几百张到几千张图片)构造高效.实用的图像分类器的方法并给出试验结果. 本文将探讨如下几种方法: 从图片中直接训练一个小网络(作为基准方法) ...

  9. tensorflow+k-means聚类 简单实现猫狗图像分类

    文章目录 一.前言 二.k-means聚类 三.图像分类 一.前言 本文使用的是 kaggle 猫狗大战的数据集:https://www.kaggle.com/c/dogs-vs-cats/data ...

最新文章

  1. 在Ubuntu 14.04 64bit上安装百度云Linux客户端BCloud
  2. tf.reverse_sequence
  3. [痛并快乐着 国外开发者总结欧美游戏坑钱指南] 讀後感想
  4. 几个重要库函数的实现
  5. R语言学习笔记(四)参数估计
  6. vue ---- vue 的入门程序
  7. Python的开源人脸识别库:离线识别率高达99.38%
  8. HDMI接口是什么?HDMI接口的基础知识讲解
  9. 计算机学科a类排名,哈工大17个学科排名位列A类
  10. 期货量化交易matlab,【策略分享】Matlab量化交易策略源码分享
  11. C语言中有关字符串的库函数(3)
  12. Resource is out of sync with the file system的解决办法
  13. java流重定向如何分类,Java 文件流与标准流之间的重定向
  14. python判断手机号运营商_python手机号码运营商归属测试
  15. 终于给自己买了台电脑
  16. 黑马程序员机器学习Day2学习笔记
  17. Apache的Order Allow,Deny 规则
  18. 无线广播相关信号(收音机)的发射与接收
  19. 开源数据库MySQL DBA运维实战 第2章 SQL1
  20. 尚硅谷在线教育七:尚硅谷在线教育项目课程管理相关的开发

热门文章

  1. 各省金融机构存贷款余额、GDP、金融化程度、城镇化率、大专以上人口比重等(2012-2019年)
  2. 管理系统中计算机应用实践大纲,管理系统中计算机应用实践技能考核大纲及操作指导...
  3. 人物-商界-杨惠妍:杨惠妍
  4. 图案设计灵感怎么写_服装设计灵感来源怎么写_服装设计理念怎么写
  5. 一梦江湖(楚留香)自用日常手机脚本
  6. 网店美工之你不知道的图片设计技巧
  7. 第33篇 Android Studio实现五子棋游戏(四)棋子类和主类
  8. python变量名必须以字母或下划线开头不区分字母大小写_Python变量名必须以字符或下划线开头,并且区分字母大小写。...
  9. 如何用c++发出音乐
  10. java必备的开发知识和技能