Kaggle竞赛中的犬种识别挑战,比赛的网址是https://www.kaggle.com/c/dog-breed-identification 在这项比赛中,尝试确定120种不同的狗。该比赛中使用的数据集实际上是著名的ImageNet数据集的子集。

基本思路

加载自定义数据集

微调ResNet18模型

训练模型

基于pytorch的代码

日常导入需要用到的python库

import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

from torchvision import transforms, datasets, models

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

np.random.seed(0)

torch.manual_seed(0)1

2

3

4

5

6

7

8

9

10

11

12

加载数据集

使用的是比赛网址上下载数据集, 格式如下

| Dog Breed Identification

| train

| | 000bec180eb18c7604dcecc8fe0dba07.jpg

| | 00a338a92e4e7bf543340dc849230e75.jpg

| | …

| test

| | 00a3edd22dc7859c487a64777fc8d093.jpg

| | 00a6892e5c7f92c1f465e213fd904582.jpg

| | …

| labels.csv

| sample_submission.csv

我们要将他转换成pytorch能识别的格式, 如下

| train_valid_test

| train

| | affenpinscher

| | | 00ca18751837cd6a22813f8e221f7819.jpg

| | | …

| | afghan_hound

| | | 0a4f1e17d720cdff35814651402b7cf4.jpg

| | | …

| | …

| valid

| | affenpinscher

| | | 56af8255b46eb1fa5722f37729525405.jpg

| | | …

| | afghan_hound

| | | 0df400016a7e7ab4abff824bf2743f02.jpg

| | | …

| | …

| train_valid

| | affenpinscher

| | | 00ca18751837cd6a22813f8e221f7819.jpg

| | | …

| | afghan_hound

| | | 0a4f1e17d720cdff35814651402b7cf4.jpg

| | | …

| | …

| test

| | unknown

| | | 00a3edd22dc7859c487a64777fc8d093.jpg

| | | …

先设置文件路径

all_path = "/home/kesci/input/Kaggle_Dog6357/dog-breed-identification"

test_path = "test"

train_path = "train"

train_label_path = "labels.csv"

valid_path = "valid"1

2

3

4

5

更据上面的路径去调整文件路径,

加载完后方便我们加载数据

# 操作文件

import os

# 拷贝文件

import shutil

def make_dir(path):

"""

判断路径是否存在:

False:创建该路径

"""

if not os.path.exists(os.path.join(*path)):

os.makedirs(os.path.join(*path))

def get_dog_data(root_path, train_path, label_path, test_path, valid_path, valid_alpha=.3):

new_dir = "new_dir"

# 加载训练集图片文件名

train_names = os.listdir(os.path.join(root_path, train_path))

np.random.shuffle(train_names)

# 加载训练集标签

labels_csv = pd.read_csv(os.path.join(root_path, label_path))

labels = {i: c for i, c in labels_csv.values}

# 验证集大小

valid_size = int(len(train_names) * valid_alpha)

for i, name in enumerate(train_names):

# 原name是name.jpg,只需要.jpg前面的部分

split_name = name.split(".")[0]

# labels -> {name: label} 将label提取出来

l = labels[split_name]

# 将数据集拷贝到valid所属文件夹中

if i < valid_size:

make_dir([root_path, new_dir, "valid", l])

shutil.copy(

# 源文件路径

os.path.join(root_path, train_path, name),

# 拷贝文件路径

os.path.join(root_path, new_dir, "valid", l)

)

else:

make_dir([root_path, new_dir, "train", l])

shutil.copy(

os.path.join(root_path, train_path, name),

os.path.join(root_path, new_dir, "train", l)

)

# 加入完整的训练集中(训练集 + 验证集)

make_dir([root_path, new_dir, "train_and_valid", l])

shutil.copy(

os.path.join(root_path, train_path, name),

os.path.join(root_path, new_dir, "train_and_valid", l)

)

make_dir([root_path, new_dir, "test", "unclass"])

for i in os.listdir(os.path.join(root_path, test_path)):

shutil.copy(

os.path.join(root_path, test_path, i),

os.path.join(root_path, new_dir, "test", "unclass")

)1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

get_dog_data(all_path, train_path, train_label_path, test_path, valid_path)

运行, 然后调整文件

根据以前所学知识, 对数据进行一些数据增强, 批量加载等

# 数据增强

train_transform = transforms.Compose([

# 图像随机裁剪大小和纵横比

transforms.RandomResizedCrop(224, scale=(0.08, 1.0),

ratio=(3.0/4.0, 4.0/3.0)),

# 图像水平翻转

transforms.RandomHorizontalFlip(),

# 更改图像亮度, 对比度, 饱和度 (色阶)

transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

test_transform = transforms.Compose([

# 将图片缩放到256

transforms.Resize(256),

# 根据图片中心点裁剪224

transforms.CenterCrop(224),

transforms.ToTensor(),

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

])

# 加载数据集

train_data = datasets.ImageFolder(os.path.join(all_path, "new_dir", train_path),

transform=train_transform)

valid_data = datasets.ImageFolder(os.path.join(all_path, "new_dir", valid_path),

transform=test_transform)

train_and_valid_data = datasets.ImageFolder(os.path.join(all_path, "new_dir",

"train_and_valid"), transform=train_transform)

test_data = datasets.ImageFolder(os.path.join(all_path, "new_dir", test_path),

transform=test_transform)

# 批量数据集

train_iter = torch.utils.data.DataLoader(train_data, batch_size=128,

shuffle=True)

valid_iter = torch.utils.data.DataLoader(valid_data, batch_size=128,

shuffle=True)

train_and_valid_iter = torch.utils.data.DataLoader(train_and_valid_data,

batch_size=128, shuffle=True)

test_iter = torch.utils.data.DataLoader(test_data, batch_size=128,

shuffle=False)1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

微调ResNet18模型

加载已经下载好的权重, 将参数冻结, 训练全连接层即可

def resnet34():

model = models.resnet34()

model.load_state_dict(torch.load(

"/home/kesci/input/resnet347742/resnet34-333f7ec4.pth"))

# 冻结参数

for para in model.parameters():

para.requires_grad = False

model.fc = nn.Sequential(

nn.Linear(512, 256),

nn.ReLU(),

nn.Linear(256, 120)

)

return model1

2

3

4

5

6

7

8

9

10

11

12

13

训练模型

和以前一样训练模型

def train(net, epochs=20, lr=0.01):

opt = optim.Adam(net.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss()

import time

for epoch in range(1, epochs):

net.train()

train_loss = 0.0

start_time = time.time()

for x, y in train_iter:

out = net(x)

loss = criterion(out, y)

train_loss += loss.float().item()

net.zero_grad()

loss.backward()

opt.step()

print(

f"Epoch -> {epoch}\t"

f"Time Out: {time.time() - start_time :.4f}sec\t"

f"Loss: {train_loss / len(train_iter) :.3f}"

)

net.eval()

valid_loss = 0

acc = 0

state_time = time.time()

for x, y in valid_iter:

out = net(x)

loss = criterion(out, y)

valid_loss += loss.float().item()

acc += (out.argmax(dim=1) == y).float().mean().item()

print(

f"Valid Time Out: {time.time() - state_time :.4f}sec\t"

f"Valid Loss: {valid_loss / len(valid_iter) :.4f}\t"

f"Accuracy: {acc / len(valid_iter) * 100 :.2f}%\nOver!"

)1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

也可以直接尝试训练整个数据集(train_and_valid), 训练时间较长, 我就不尝试了

然后用训练好的模型去分类测试集的的图片

python狗品种识别_狗品种识别相关推荐

  1. python计算狗的年龄_狗一岁相当于人几岁?怎样确定狗狗的年龄

    解答:狗是生活中比较常见的一种动物,有人提出狗的一岁相当于人的7岁实际上这种说法是不准确的.假如按照人的年龄换算的话,0-1岁的狗是16-17岁左右年纪的人,而2岁的狗则相当于23岁左右的人. 狗一岁 ...

  2. python 人脸轮廓提取_实现人脸识别、人脸68个特征点提取,或许这个 Python 库能帮到你!...

    以前写过一篇关于实现人脸识别的文章,里面用到的技术是经过调用百度 API 实现的,本次将借助于 dlib  程序包实现人脸区域检测.特征点提取等功能,html dlib 封装了许多优秀的机器学习算法, ...

  3. python 命名实体识别_命名实体识别的两种方法

    作者:Walker 目录 一.什么是命名实体识别 二.基于NLTK的命名实体识别 三.基于Stanford的NER 四.总结 一 .什么是命名实体识别? 命名实体识别(Named Entity Rec ...

  4. python机器视觉车牌识别_机器视觉车牌识别

    机器视觉车牌识别 --车牌号识别系统研究课题 2018年7月10日,许昌学院信息工程(软件职业技术)学院"创出彩"机器视觉智能检测实践队第10天研究正式开展,由于老师有别的事情要忙 ...

  5. 人脸识别_云端人脸识别-人脸识别SDK+API-人脸识别闸机解决方案

    云端人脸识别-人脸识别SDK+API-人脸识别闸机解决方案 人脸识别闸机-人脸识别闸机解决方案 软硬一体的人脸识别闸机解决方案,提升人员系统化管理的安全性与便捷性 方案构成 针对人员出入的闸机及门禁场 ...

  6. 人脸反光识别和读数识别_云端人脸识别-人脸识别SDK+API-人脸识别闸机解决方案...

    云端人脸识别-人脸识别SDK+API-人脸识别闸机解决方案 人脸识别闸机-人脸识别闸机解决方案 软硬一体的人脸识别闸机解决方案,提升人员系统化管理的安全性与便捷性 方案构成 针对人员出入的闸机及门禁场 ...

  7. 夜间环境人脸识别_动态人脸识别系统的优势

    TH-894是一款天煌电子全新的三防动态人脸识别xt终端采用嵌入式系统.功耗低,运行更稳定.数据更安全.使用高性能智能处理器,基于深度学习的人脸识别与抓拍信息提取,极大的提高了人脸抓拍率.采用夜间红外 ...

  8. python画小动物_三分钟识别所有小动物!

    大家是不是有过这样的经历:看到一只可爱的小动物却不知道这是什么品种?或者看到一个美丽的妹子牵着一只小动物却不知道如何搭讪?现在机会来了,免费领取你的人工智能AI自动识别小动物!当然猪猪也是可以的!!! ...

  9. python爬取换页_一个可识别翻页的简易Python爬虫程序

    同学拿出一个需求:从某课程教学网站上爬取所有课程的主页面,以及课程简介栏目内容. 于是在之前做的那个练手级的Python爬虫程序中进行修改,最终实现了该功能.与之前那个爬虫不同,这里每一个大类的课程下 ...

最新文章

  1. 朱哥研究出来的分页控件
  2. jsp中获取不到后台请求域中的值
  3. webflux切面拦截权限,webflux整合aop,webflux获取request
  4. List的Clear方法与RemoveAll方法用法小结
  5. java 类似datatable_在java中实现类似于.net中的DataTable,请各位看看,这种方法可行吗?...
  6. 关于Linux环境变量
  7. HTML图形映射技术
  8. php java session共享_PHP实现多服务器session共享之NFS共享
  9. xmpp协议抓包_抓包工具有哪些?大佬们常用的18款抓包工具就是这些
  10. jQuery基础集锦——插件开发
  11. 一款简洁大气的个人主页源码
  12. 基于php034医院电子病历住院病人
  13. springboot+vue公众号页面授权获得微信openId
  14. 计算机网络技术广告,屏蔽QQ广告和迷你首页广告
  15. 欧洲杯赛场“中国元素”引观众热议;万达两家酒店在延安红街开业窑洞房最具特色 | 美通社头条...
  16. Java加上Xtend,满足你对C#语法的所有想象
  17. 深入浅出TVS瞬态抑态二极管
  18. WPF学习之深入浅出话模板
  19. 利用CANoe Vector LDF Explorer Pro创建LDF文件
  20. FC协议功能子模块,实现FC-1553协议,ASM协议,AV协议的应用,多种接口可定制

热门文章

  1. 问题1484:小鱼的刷剧时光
  2. 【月伴流星】GhostW7_SP1_U_x86_V2013.06_OEM通用纯净、装机、美化版(三版齐发)
  3. android RecyclerView禁止多点触控
  4. 2018Python元年
  5. wincc提示计算机丢失ccctrl,WinCC(变量记录和组态报警)
  6. Tune:一个分布式模型选择与训练研究平台
  7. VirtualBox安装Ubuntu18
  8. linux常用命令与实例小全
  9. 什么是神经网络技术,三种常见的神经网络
  10. MIPI DBI介绍