数据准备

猫狗大战数据集下载链接

微软的数据集已经分好类,直接使用就行,

数据划分

我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil.move()来移动图片。

新建文件夹train,test,将数据集放入train中,利用代码将10%的数据移动到test中

文件移动代码

import os
import shutil
source_path = r"E:\猫狗大战数据集\PetImages"
train_dir = os.path.join(source_path, "train")
test_dir = os.path.join(source_path,"test")
train_dir_list = os.listdir(train_dir)
for dir in train_dir_list:category_dir_path = os.path.join(train_dir, dir)image_file_list = os.listdir(category_dir_path)num = int(0.1*len(image_file_list))#移动10%文件到对应目录for i in range(num):shutil.move(os.path.join(category_dir_path,image_file_list[i]),os.path.join(test_dir,dir,image_file_list[i]))

移动后

数据可视化

import matplotlib.pyplot as plt
import numpy
import os
from PIL import Image #读取图片模块
from matplotlib.image import imread
source_path = r"E:\猫狗大战数据集\PetImages"
#分别从Dog,Cat文件夹中选取10张图片显示
train_Dog_dir = os.path.join(source_path, "train","Dog")
train_Cat_dir = os.path.join(source_path, "train","Cat")
Dog_image_list = os.listdir(train_Dog_dir)Cat_image_list = os.listdir(train_Cat_dir)
show_image = [os.path.join(train_Dog_dir,Dog_image_list[i]) for i in range(10)]
show_image.extend([os.path.join(train_Cat_dir,Cat_image_list[i]) for i in range(10)])
for i in show_image:print(i)
plt.figure()for i in range(1,20):plt.subplot(4,5,i)img = Image.open(show_image[i-1])plt.imshow(img)plt.show()

效果图:

可以看出图片的尺寸不同,在数据预处理时需要将图片resize,

使用预训练模型(resnet)进行训练

from    torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
import torchvision.models as models
import torch.optim as optimfrom visdom import Visdomif __name__ == '__main__':#数据处理data_transform = transforms.Compose([transforms.Resize(128),transforms.CenterCrop(128),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])train_dataset = datasets.ImageFolder(root=r'E:/猫狗大战数据集/PetImages/train/', transform=data_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)test_dataset = datasets.ImageFolder(root=r'E:/猫狗大战数据集/PetImages/test/', transform=data_transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)#损失函数criteon = nn.CrossEntropyLoss()#加载预训练模型transfer_model = models.resnet18(pretrained=True)dim_in = transfer_model.fc.in_featurestransfer_model.fc = nn.Linear(dim_in, 2)#优化器adamoptimizer = optim.Adam(transfer_model.parameters(), lr=0.01)#加载模型到GPUtransfer_model = transfer_model.cuda()viz = Visdom()viz.line([[0.0,0.0]],[0.],win='train',opts=dict(title="train_loss&&acc", legend=['loss','acc']))viz.line([[0.0,0.0]], [0.], win='test', opts=dict(title="test loss&&acc.",legend=['loss', 'acc']))global_step =0#模型训练transfer_model.train()for epoch in range(10):train_acc_num =0test_acc_num =0for batch_idx,(data,target) in enumerate(train_loader):data, target = data.cuda(), target.cuda()#投入数据,得到预测值logits = transfer_model(data)_,pred = torch.max(logits.data,1)#print(pred, target)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()#准确度计算train_acc_num += pred.eq(target).float().sum().item()#print("准确数:",train_acc_num," ",batch_idx, " ",len(data))train_acc = train_acc_num/((batch_idx+1)*len(data))#print(train_acc)#print(train_acc.item())global_step +=1viz.line([[loss.item(), train_acc]],[global_step],win='train',update='append')if batch_idx %200 ==0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f},acc:{}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item(),train_acc))test_loss =0for data, target in test_loader:data, target = data.cuda(), target.cuda()logits = transfer_model(data)test_loss += criteon(logits,target).item()_, pred = torch.max(logits.data, 1)# 准确度计算test_acc_num += pred.eq(target).float().sum().item()viz.line([[test_loss / len(test_loader.dataset), test_acc_num / len(test_loader.dataset)]],[global_step], win='test', update='append')test_acc = train_acc_num / len(test_loader.dataset)viz.images(data.view(-1, 3, 128, 128), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, test_acc, len(test_loader.dataset),100. * test_acc / len(test_loader.dataset)))

Train Epoch: 0 [0/22498 (0%)]    Loss: 1.061759,acc:0.25
Train Epoch: 0 [800/22498 (4%)]    Loss: 0.708053,acc:0.5174129353233831
Train Epoch: 0 [1600/22498 (7%)]    Loss: 0.403057,acc:0.5155860349127181
Train Epoch: 0 [2400/22498 (11%)]    Loss: 0.721054,acc:0.5033277870216306
Train Epoch: 0 [3200/22498 (14%)]    Loss: 0.629318,acc:0.5037453183520599

其中visdom是模型可视化模块

深度学习实战---猫狗大战(pytorch实现)相关推荐

  1. python猫狗大战pytorch_深度学习实战---猫狗大战(pytorch实现)

    数据准备 猫狗大战数据集下载链接 微软的数据集已经分好类,直接使用就行, 数据划分 我们将猫和狗的图片分别移动到训练集和验证集中,其中90%的数据作为训练集,10%的图片作为验证集,使用shutil. ...

  2. 深度学习实战_五天入门深度学习,这里有一份PyTorch实战课程

    这是一门五天入门深度学习的实战课程. 想入门深度学习的小伙伴有福了!dataflowr 最近推出了一门五天初步掌握深度学习的实战教程(实战使用 PyTorch 框架),有知识点有实例有代码,值得一看. ...

  3. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  4. 深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割

    深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割 1. 项目简介 2. 3D医学图像分割的需求 3. 医学图像和MRI 4. 三维医学图像表示 5. 3D-Unet模型 5.1损失函 ...

  5. 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1

    日萌社 人工智能AI:TensorFlow Keras PyTorch MXNet PaddlePaddle 深度学习实战 part1 人工智能AI:TensorFlow Keras PyTorch ...

  6. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

  7. 深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测

    文章目录 一.前期工作 导入库包 导入数据 主成分分析(PCA) 聚类分析(K-means) 二.神经网络模型建立 三.检验模型 大家好,我是微学AI,今天给大家带来一个利用卷积神经网络(pytorc ...

  8. PyTorch深度学习图像分类--猫狗大战

    PyTorch深度学习图像分类--猫狗大战 1.背景介绍 2.环境配置 2.1软硬件清单 2.1.1配置PyPorch 2.1.2开发软件 2.1.3 显卡 2.2 数据准备 3 基础理论 3.1Py ...

  9. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

最新文章

  1. 肠道微生物的研究不复杂,不信看这篇Science
  2. boost::hana::extend用法的测试程序
  3. Django中管理并发操作
  4. 无法启动baiMicrosoft Office Outlook。无法打开duOutlook窗口
  5. Docker Nacos Mysql集群
  6. Ubuntu系统全盘备份与恢复,亲自总结,实测可靠
  7. 华为手机设置页面黑色_华为手机里最危险的设置,学会这一招,手机还能再战三年...
  8. python3.6 asyncio paramiko_网工的Python之路:Concurrent.Futures
  9. 0基础入门学PLC,只需掌握好这5个步骤让您从0变精通
  10. 2019深圳杯数学建模比赛--初步思路
  11. 极光短信推送-java使用
  12. iPhoneX开了个好头,苹果今年将推廉价版iPhoneX柏颖
  13. 什么是瑞士加密谷Crypto Valley、CV Labs
  14. linux audacity,Audacity使用教程 Audacity怎么用
  15. 京东2020校招笔试题-算法工程师
  16. 张朝阳的心境,搜狐的武器
  17. android 编译ninja,Ninja编译过程分析
  18. 【Android Activity】Activity的生命周期
  19. mpu6050判断自由落体状态的方法
  20. 如何搭建自己的Teamspeak服务器?(Windows)

热门文章

  1. 区块链论语:价值投资及区块链应用
  2. 一个准大四狗的内心独白
  3. 阅读记录|《远山淡影》
  4. 对deap数据集进行脑电情绪识别并进行频谱分析(频域特征)
  5. FDA认证咨询,委托实验室或者生产厂家必须进行抑菌/抑真菌测试。
  6. 1. Jewels and Stones (宝石与石头)
  7. 信用社考试计算机知识,信用社考试(计算机基础知识 六)
  8. 我把所有的精华文章都整理出来了
  9. GA/T 1400协议 - 注册注销流程
  10. c语言结构体工人评优题,C语言复习习题-结构体