动手学深度学习——4. 猫狗大战

记录一下学习深度学习的一些。本篇简述如何在 Windows 上训练一个模型来识别猫狗。
所使用的环境:

  • Windows 10
  • 8700K
  • GeForce RTX 2070
  • CUDA 10.1
  • Python 3.8
  • Pytorch 1.7.1

数据准备

猫狗大战的数据集找不到官网,这里使用 Kaggle 的数据集,也提供百度网盘的下载地址。

  • Kaggle
  • 百度网盘:8nr6

数据查看

下载完数据后,解压,可以看到训练数据集的图片都以 cat.xxx.jpgdog.xxx.jpg 命名,可以从名字中获取标签,而测试数据集的标签无法提取,emmm,不嫌累的可以直接一张张标注。

再查看一些其他信息。

可以看到有 12,500 只猫和 12,500 只狗,图像的高宽大约为500,有部分特别小,有两个特别大,查看一下

可以看到其中有部分脏数据,这对模型的训练会造成很大的困扰,但是 25,000 张图像要清洗一遍太费精力了,待解决。

定义数据集

数据集中的图像有大有小,训练模型要求输入是统一大小的,直接将图像缩放至统一大小对人类判别猫狗影响不大,因此将所有的图像全部缩放至 256,再随机裁剪至 224,进行数据增强。

transform = {'train': transforms.Compose([transforms.Resize((256,256)),transforms.RandomCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((256,256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}datasets = {'train': DogsCatsSet(train_set, transform['train']),'val': DogsCatsSet(val_set, transform['val'])
}data_sizes = {'train': len(datasets['train']),'val': len(datasets['val'])
}dataloaders = {'train': DataLoader(datasets['train'], batch_size=batch_size, shuffle=True, pin_memory=True),'val': DataLoader(datasets['val'], batch_size=batch_size, shuffle=False, pin_memory=False)
}

训练

用 ResNet 来训练猫狗分类器,使用 ImageNet 的预训练权重,进行微调

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import resnet50device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_ft = resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_tf = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
model_tf = train(model_ft, dataloaders, dataset_sizes, criterion, optimizer, exp_lr_scheduler, device, 20)

测试

本次测试由于测试集没有真实标签,所以就无办法直接验证测试集的准确率了,测试部分图像并将输入图像展示出来

test_transform = transforms.Compose([transforms.Resize((256,256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
test_set = DogsCatsSet(test_list, test_transform)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True, pin_memory=True)
visualize_preds(model_ft, device, test_dataloader, 64)

可以看到测试的大部分结果都是正确的,但对于其他的测试样本或者真实的数据表现如何,就要实践才知道了。

保存模型

save_path = 'dogs_vs_cats.pt'
torch.save(model_ft.state_dict(), save_path)

Code

  • notebook
  • script

动手学深度学习——4. 猫狗大战相关推荐

  1. 动手学深度学习——5. 数据清洗

    动手学深度学习--5. 数据清洗 记录一下学习深度学习的一些.本篇简述如何使用 cleanlab 清洗分类数据 所使用环境: Ubuntu 16.04 8700K GeForce RTX 1080Ti ...

  2. 「动手学深度学习」在B站火到没谁,加这个免费实操平台,妥妥天花板!

    论 AI 圈活菩萨,非李沐老师莫属. 前有编写「动手学深度学习」,成就圈内入门经典,后又在B站免费讲斯坦福 AI 课,一则艰深硬核讲论文的视频播放量36万,不少课题组从导师到见习本科生都在追番. 如此 ...

  3. 《动手学深度学习》中文第二版预览版发布

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨李沐@知乎 来源丨https://zhuanlan.zhihu ...

  4. 收藏 |《动手学深度学习》中文版PDF

    对于初学者来说,直接阅读英文资料,效率慢,估计读着读着都没有信心读下去了.对于初学者,中文资料是再好不过了.今天小编就来安利一本中文资料--中文版本的<动手学深度学习>. 资料领取: 扫码 ...

  5. 深度学习经典教程:深度学习+动手学深度学习

    作者:[美] Ian,Goodfellow(伊恩·古德费洛),[加] Yoshua,Bengio(约书亚·本吉奥)等 出版社:人民邮电出版社 品牌:异步图书 出版时间:2019-06-01 深度学习经 ...

  6. 资源 | 李沐等人开源中文书《动手学深度学习》预览版上线

    来源:机器之心 本文约2000字,建议阅读10分钟. 本文为大家介绍了一本交互式深度学习书籍. 近日,由 Aston Zhang.李沐等人所著图书<动手学深度学习>放出了在线预览版,以供读 ...

  7. 最新版 | 2020李沐《动手学深度学习》

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 强烈推荐李沐等人的<动手学深度学习>最新版!完整中文版 PDF 终于 在 ...

  8. 动手学深度学习需要这些数学基础知识

    https://www.toutiao.com/a6716993354439066124/ 本附录总结了本书中涉及的有关线性代数.微分和概率的基础知识.为避免赘述本书未涉及的数学背景知识,本节中的少数 ...

  9. 《动手学深度学习》PyTorch版GitHub资源

    之前,偶然间看到过这个PyTorch版<动手学深度学习>,当时留意了一下,后来,着手学习pytorch,发现找不到这个资源了.今天又看到了,赶紧保存下来. <动手学深度学习>P ...

  10. 用PyTorch实现的李沐《动手学深度学习》,登上GitHub热榜,获得700+星

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 李沐老师的<动手学深度学习>是一本入门深度学习的优秀教材,也是各大在线书店的计算机类畅销书. 作为MXNet的作者之一,李沐老 ...

最新文章

  1. jQuery.delegate() 函数详解
  2. micrometer_具有InlfuxDB的Spring Boot和Micrometer第1部分:基础项目
  3. 在2008 server安装vm server时发生的错误error1718、error1335……
  4. Git Flow工作流图
  5. Linux运维:cobbler
  6. Vivado常见问题集锦
  7. C#.Net工作笔记015---C#中Decimal类型四舍五入_小数点截位
  8. opencv4版本和3版本_Spring Boot 太狠了,一口气发布了 3 个版本!
  9. [软件工程基础]结队项目——地铁
  10. rs232转usb线故障(ft232r usb uart驱动安装失败)
  11. StrokeIt-单手摸鱼的快乐你想象不到
  12. VMware Workstation虚拟机安装及虚拟机搭建(内有虚拟机安装包及序列号和系统镜像)...
  13. 手机淘宝APP关键词搜索采集方案
  14. python闯关训练营怎么样3.0_泡着枸杞写bug的三流程序员凭什么逆袭到一线大厂?...
  15. SQL分组选取时间最大的记录
  16. HashMap引发死链问题(HashMap、ConcurrentHashMap原理解析)
  17. 【coolshell酷壳】简明 Vim 练级攻略
  18. 群晖消息通知 推送服务器,群晖开启系统信息微信推送服务
  19. html5视频播放av,7月AHA急救课程报名中!掌握埃里克森心脏骤停的获救技能!!...
  20. java/php/net/python养花助手平台设计

热门文章

  1. [生存志] 第51节 子产相郑铸刑书
  2. 3399 android root,RK3288/3399 Android Root方法
  3. PISCES P4-vSwitch 安装以及一次失败的测试
  4. 高效办公之云端实时协作企业办公软件:石墨文档
  5. influence和effect的区别
  6. 蓝桥杯第七届省赛 模拟风扇控制系统 by YYC
  7. 【String-easy】551. Student Attendance Record I 学生迟到和旷课
  8. 西部世界分析:人民网点名IPFS 分布式存储打开千亿级市场
  9. java对象转excel_【转】JAVA实现EXCEL的导入和导出(一)
  10. 如何做好DevOps Secrets管理