动手学深度学习——4. 猫狗大战
动手学深度学习——4. 猫狗大战
记录一下学习深度学习的一些。本篇简述如何在 Windows 上训练一个模型来识别猫狗。
所使用的环境:
- Windows 10
- 8700K
- GeForce RTX 2070
- CUDA 10.1
- Python 3.8
- Pytorch 1.7.1
数据准备
猫狗大战的数据集找不到官网,这里使用 Kaggle 的数据集,也提供百度网盘的下载地址。
- Kaggle
- 百度网盘:8nr6
数据查看
下载完数据后,解压,可以看到训练数据集的图片都以 cat.xxx.jpg
和 dog.xxx.jpg
命名,可以从名字中获取标签,而测试数据集的标签无法提取,emmm,不嫌累的可以直接一张张标注。
再查看一些其他信息。
可以看到有 12,500 只猫和 12,500 只狗,图像的高宽大约为500,有部分特别小,有两个特别大,查看一下
可以看到其中有部分脏数据,这对模型的训练会造成很大的困扰,但是 25,000 张图像要清洗一遍太费精力了,待解决。
定义数据集
数据集中的图像有大有小,训练模型要求输入是统一大小的,直接将图像缩放至统一大小对人类判别猫狗影响不大,因此将所有的图像全部缩放至 256,再随机裁剪至 224,进行数据增强。
transform = {'train': transforms.Compose([transforms.Resize((256,256)),transforms.RandomCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((256,256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}datasets = {'train': DogsCatsSet(train_set, transform['train']),'val': DogsCatsSet(val_set, transform['val'])
}data_sizes = {'train': len(datasets['train']),'val': len(datasets['val'])
}dataloaders = {'train': DataLoader(datasets['train'], batch_size=batch_size, shuffle=True, pin_memory=True),'val': DataLoader(datasets['val'], batch_size=batch_size, shuffle=False, pin_memory=False)
}
训练
用 ResNet 来训练猫狗分类器,使用 ImageNet 的预训练权重,进行微调
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import resnet50device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_ft = resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_tf = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
model_tf = train(model_ft, dataloaders, dataset_sizes, criterion, optimizer, exp_lr_scheduler, device, 20)
测试
本次测试由于测试集没有真实标签,所以就无办法直接验证测试集的准确率了,测试部分图像并将输入图像展示出来
test_transform = transforms.Compose([transforms.Resize((256,256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
test_set = DogsCatsSet(test_list, test_transform)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True, pin_memory=True)
visualize_preds(model_ft, device, test_dataloader, 64)
可以看到测试的大部分结果都是正确的,但对于其他的测试样本或者真实的数据表现如何,就要实践才知道了。
保存模型
save_path = 'dogs_vs_cats.pt'
torch.save(model_ft.state_dict(), save_path)
Code
- notebook
- script
动手学深度学习——4. 猫狗大战相关推荐
- 动手学深度学习——5. 数据清洗
动手学深度学习--5. 数据清洗 记录一下学习深度学习的一些.本篇简述如何使用 cleanlab 清洗分类数据 所使用环境: Ubuntu 16.04 8700K GeForce RTX 1080Ti ...
- 「动手学深度学习」在B站火到没谁,加这个免费实操平台,妥妥天花板!
论 AI 圈活菩萨,非李沐老师莫属. 前有编写「动手学深度学习」,成就圈内入门经典,后又在B站免费讲斯坦福 AI 课,一则艰深硬核讲论文的视频播放量36万,不少课题组从导师到见习本科生都在追番. 如此 ...
- 《动手学深度学习》中文第二版预览版发布
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨李沐@知乎 来源丨https://zhuanlan.zhihu ...
- 收藏 |《动手学深度学习》中文版PDF
对于初学者来说,直接阅读英文资料,效率慢,估计读着读着都没有信心读下去了.对于初学者,中文资料是再好不过了.今天小编就来安利一本中文资料--中文版本的<动手学深度学习>. 资料领取: 扫码 ...
- 深度学习经典教程:深度学习+动手学深度学习
作者:[美] Ian,Goodfellow(伊恩·古德费洛),[加] Yoshua,Bengio(约书亚·本吉奥)等 出版社:人民邮电出版社 品牌:异步图书 出版时间:2019-06-01 深度学习经 ...
- 资源 | 李沐等人开源中文书《动手学深度学习》预览版上线
来源:机器之心 本文约2000字,建议阅读10分钟. 本文为大家介绍了一本交互式深度学习书籍. 近日,由 Aston Zhang.李沐等人所著图书<动手学深度学习>放出了在线预览版,以供读 ...
- 最新版 | 2020李沐《动手学深度学习》
点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 强烈推荐李沐等人的<动手学深度学习>最新版!完整中文版 PDF 终于 在 ...
- 动手学深度学习需要这些数学基础知识
https://www.toutiao.com/a6716993354439066124/ 本附录总结了本书中涉及的有关线性代数.微分和概率的基础知识.为避免赘述本书未涉及的数学背景知识,本节中的少数 ...
- 《动手学深度学习》PyTorch版GitHub资源
之前,偶然间看到过这个PyTorch版<动手学深度学习>,当时留意了一下,后来,着手学习pytorch,发现找不到这个资源了.今天又看到了,赶紧保存下来. <动手学深度学习>P ...
- 用PyTorch实现的李沐《动手学深度学习》,登上GitHub热榜,获得700+星
晓查 发自 凹非寺 量子位 报道 | 公众号 QbitAI 李沐老师的<动手学深度学习>是一本入门深度学习的优秀教材,也是各大在线书店的计算机类畅销书. 作为MXNet的作者之一,李沐老 ...
最新文章
- jQuery.delegate() 函数详解
- micrometer_具有InlfuxDB的Spring Boot和Micrometer第1部分:基础项目
- 在2008 server安装vm server时发生的错误error1718、error1335……
- Git Flow工作流图
- Linux运维:cobbler
- Vivado常见问题集锦
- C#.Net工作笔记015---C#中Decimal类型四舍五入_小数点截位
- opencv4版本和3版本_Spring Boot 太狠了,一口气发布了 3 个版本!
- [软件工程基础]结队项目——地铁
- rs232转usb线故障(ft232r usb uart驱动安装失败)
- StrokeIt-单手摸鱼的快乐你想象不到
- VMware Workstation虚拟机安装及虚拟机搭建(内有虚拟机安装包及序列号和系统镜像)...
- 手机淘宝APP关键词搜索采集方案
- python闯关训练营怎么样3.0_泡着枸杞写bug的三流程序员凭什么逆袭到一线大厂?...
- SQL分组选取时间最大的记录
- HashMap引发死链问题(HashMap、ConcurrentHashMap原理解析)
- 【coolshell酷壳】简明 Vim 练级攻略
- 群晖消息通知 推送服务器,群晖开启系统信息微信推送服务
- html5视频播放av,7月AHA急救课程报名中!掌握埃里克森心脏骤停的获救技能!!...
- java/php/net/python养花助手平台设计
热门文章
- [生存志] 第51节 子产相郑铸刑书
- 3399 android root,RK3288/3399 Android Root方法
- PISCES P4-vSwitch 安装以及一次失败的测试
- 高效办公之云端实时协作企业办公软件:石墨文档
- influence和effect的区别
- 蓝桥杯第七届省赛 模拟风扇控制系统 by YYC
- 【String-easy】551. Student Attendance Record I 学生迟到和旷课
- 西部世界分析:人民网点名IPFS 分布式存储打开千亿级市场
- java对象转excel_【转】JAVA实现EXCEL的导入和导出(一)
- 如何做好DevOps Secrets管理