比赛地址:https://www.kaggle.com/c/leaf-classification/rules
完整代码:https://github.com/SPECTRELWF/kaggle_competition
个人主页:liuweifeng.top:8090

比赛题目:对树叶的类别进行分类,树叶总共99个类别,树叶的图片如下:


我也不知道怎么分类,反正总共有99中类别的树叶。下载到的数据集解压后如下:

image里面存了所有的树叶图像,train.csv是训练文件的标号以及类别,后面有一堆的特征,我没用到,因为比赛已经结束了,我只是纯纯的拿了练习下CNN。test.csv文件是测试数据的标号,sample_submission.csv文件是提交样例,长这样:

第一列是id,后面的99列是对应的每个类别的概率,分类结果加上softmax就行。

思路:

直接使用的基于ImageNet预训练的resnet101,微调一下。

预处理

将训练集的id和label写到一个txt文件中,测试集的id写入另一个txt文件:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 上午10:27import os
import pandas as pd
classes = ['Acer_Capillipes', 'Acer_Circinatum', 'Acer_Mono', 'Acer_Opalus', 'Acer_Palmatum', 'Acer_Pictum', 'Acer_Platanoids', 'Acer_Rubrum', 'Acer_Rufinerve', 'Acer_Saccharinum', 'Alnus_Cordata', 'Alnus_Maximowiczii', 'Alnus_Rubra', 'Alnus_Sieboldiana', 'Alnus_Viridis', 'Arundinaria_Simonii', 'Betula_Austrosinensis', 'Betula_Pendula', 'Callicarpa_Bodinieri', 'Castanea_Sativa', 'Celtis_Koraiensis', 'Cercis_Siliquastrum', 'Cornus_Chinensis', 'Cornus_Controversa', 'Cornus_Macrophylla', 'Cotinus_Coggygria', 'Crataegus_Monogyna', 'Cytisus_Battandieri', 'Eucalyptus_Glaucescens', 'Eucalyptus_Neglecta', 'Eucalyptus_Urnigera', 'Fagus_Sylvatica', 'Ginkgo_Biloba', 'Ilex_Aquifolium', 'Ilex_Cornuta', 'Liquidambar_Styraciflua', 'Liriodendron_Tulipifera', 'Lithocarpus_Cleistocarpus', 'Lithocarpus_Edulis', 'Magnolia_Heptapeta', 'Magnolia_Salicifolia', 'Morus_Nigra', 'Olea_Europaea', 'Phildelphus', 'Populus_Adenopoda', 'Populus_Grandidentata', 'Populus_Nigra', 'Prunus_Avium', 'Prunus_X_Shmittii', 'Pterocarya_Stenoptera', 'Quercus_Afares', 'Quercus_Agrifolia', 'Quercus_Alnifolia', 'Quercus_Brantii', 'Quercus_Canariensis', 'Quercus_Castaneifolia', 'Quercus_Cerris', 'Quercus_Chrysolepis', 'Quercus_Coccifera', 'Quercus_Coccinea', 'Quercus_Crassifolia', 'Quercus_Crassipes', 'Quercus_Dolicholepis', 'Quercus_Ellipsoidalis', 'Quercus_Greggii', 'Quercus_Hartwissiana', 'Quercus_Ilex', 'Quercus_Imbricaria', 'Quercus_Infectoria_sub', 'Quercus_Kewensis', 'Quercus_Nigra', 'Quercus_Palustris', 'Quercus_Phellos', 'Quercus_Phillyraeoides', 'Quercus_Pontica', 'Quercus_Pubescens', 'Quercus_Pyrenaica', 'Quercus_Rhysophylla', 'Quercus_Rubra', 'Quercus_Semecarpifolia', 'Quercus_Shumardii', 'Quercus_Suber', 'Quercus_Texana', 'Quercus_Trojana', 'Quercus_Variabilis', 'Quercus_Vulcanica', 'Quercus_x_Hispanica', 'Quercus_x_Turneri', 'Rhododendron_x_Russellianum', 'Salix_Fragilis', 'Salix_Intergra', 'Sorbus_Aria', 'Tilia_Oliveri', 'Tilia_Platyphyllos', 'Tilia_Tomentosa', 'Ulmus_Bergmanniana', 'Viburnum_Tinus', 'Viburnum_x_Rhytidophylloides', 'Zelkova_Serrata']train_txt = open('train.txt','w')
train_csv = pd.read_csv(r'leaf-classification/train.csv')
ids = train_csv['id']
species = train_csv['species']for i in range(len(ids)):train_txt.write(str(ids[i]))train_txt.write(' ')train_txt.write(str(classes.index(str(species[i]))))train_txt.write('\n')
train_txt.close()test_txt = open('test.txt','w')
test_csv = pd.read_csv(r'leaf-classification/test.csv')
ids = test_csv['id']
for i in range(len(ids)):test_txt.write(str(ids[i]))test_txt.write('\n')
test_txt.close()

模型resnet101

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 上午10:24
import torch
import torchvision.models
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as modelsclass resnet101(nn.Module):def __init__(self, num_classes=1000):super(resnet101, self).__init__()self.num_classes = num_classesself.feature_extract = torchvision.models.resnet101(pretrained=True)self.net = nn.Sequential(nn.Linear(1000, 512),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, num_classes),)def forward(self, x):x = self.feature_extract(x)x = self.net(x)return x# x = torch.randn((1,3,224,224))
# net = resnet101(num_classes=99)
# print(net)
# print(net(x).shape)

dataloader

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 上午10:24
import numpy as np
import torch.utils.data as data
import torch
import torchvision.transforms as transforms
from PIL import Image
data_root = r'leaf-classification/images/'class leaf_Dataset(data.Dataset):def __init__(self,is_train=True,transform=None):self.is_train = is_trainself.transform = transformself.images = []self.labels = []if is_train:file = open('train.txt','r')lines = file.readlines()for line in lines:res = line[:-1]image = res.split(' ')[0]label = int(res.split(' ')[1])self.images.append(image)self.labels.append(label)print(self.images)print(self.labels)def __len__(self):return len(self.images)def __getitem__(self, index):image_name = self.images[index] + '.jpg'image_path = data_root + image_nameimg = Image.open(image_path).convert('RGB')# print(img)img = self.transform(img)label = self.labels[index]label = torch.from_numpy(np.array(label))return img, labeltransforms = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()
])
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 上午10:25"""
使用imagenet预训练的rennet101来在树叶数据集上面进行微调
"""
import torch
import torchvision.transforms as transforms
from dataset import leaf_Dataset
import torch.utils.data as data
import torch.optim as optim
import torch.nn as nn
from resnet import resnet101
#使用Adam优化器来训练网络,不冻结参数# 设置hyperparameterepoch = 200
lr = 1e-3
b1 = 0.9
b2 = 0.999
device = torch.device('cuda:0')
train_loss = []
# 初始化网络模型
net = resnet101(num_classes=99)
net.to(device)# load data
transforms = transforms.Compose([transforms.Resize((224,224)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.ToTensor(),
])
data = leaf_Dataset(is_train=True,transform=transforms)
dataloader = torch.utils.data.DataLoader(data,batch_size=64,shuffle=True)loss_func = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(),lr=lr,betas=(b1,b2))for epoch in range(1,epoch + 1):for i, (x,y) in enumerate(dataloader):x = x.to(device)y = y.to(device)pred = net(x)loss = loss_func(pred,y)opt.zero_grad()loss.backward()opt.step()train_loss.append(loss.item())print("epoch: %d   batch_idx:%d   loss:%.3f" %(epoch,i,loss.item()))torch.save(net.state_dict(),'model/epoch:%d'%epoch + '.pth')
from utils import plot_curve
plot_curve(train_loss)

loss

将预测结果写入要提交的文件

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 下午5:42
import torch
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
from resnet import resnet101
import torch.nn.functional as Fimage_path = r'leaf-classification/images'
f = open('test.txt','r')
tmp = f.readlines()
test_file = []
for i in tmp:i = i[:-1]test_file.append(i+'.jpg')
print(test_file)device = torch.device('cuda:0')
net = resnet101(num_classes=99)
print('load weight........')
net.load_state_dict(torch.load('model/epoch:200.pth'))
net.to(device)
net.eval()
transformss = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor()
])
res = []
with torch.no_grad():for image in test_file:img = Image.open(os.path.join(image_path,image)).convert('RGB')img = transformss(img)img = torch.unsqueeze(img,dim=0)img = img.to(device)# print(img.shape)pred = net(img)pred = F.softmax(pred).flatten()pred = pred.cpu().numpy()print(pred)res.append(pred)np.savetxt("result.csv",res,delimiter = ',')

实战Kaggle比赛(1):树叶分类相关推荐

  1. 沐神《动手学深度实战Kaggle比赛:狗的品种识别(ImageNet Dogs)

    沐神<动手学深度学习>飞桨版课程公开啦! hello各位飞桨的开发者,大家好!李沐老师的<动手学深度学习>飞桨版课程已经公开啦.本课程由PPSIG和飞桨工程师共同建设,将原书中 ...

  2. 动手学深度学习:3.16 实战Kaggle比赛:房价预测

    3.16 实战Kaggle比赛:房价预测 作为深度学习基础篇章的总结,我们将对本章内容学以致用.下面,让我们动手实战一个Kaggle比赛:房价预测.本节将提供未经调优的数据的预处理.模型的设计和超参数 ...

  3. 实战Kaggle比赛----预测房价(多层感知机)

    文章目录 实战Kaggle比赛----预测房价 下载和缓存数据集 Kaggle简介 访问和读取数据集 数据预处理 标准正态化.缺失值填充.离散值one-hot编码 小栗子帮助理解 训练 KKK折交叉验 ...

  4. 《动手深度学习》4.10. 实战Kaggle比赛:预测房价

    4.10. 实战Kaggle比赛:预测房价 本节内容预览 数据 下载和缓存数据集 访问和读取数据集 使用pandas读入并处理数据 数据预处理 处理缺失值&对数值类数据标准化 处理离散值-on ...

  5. 04.10. 实战Kaggle比赛:预测房价

    4.10. 实战Kaggle比赛:预测房价 详细介绍数据预处理.模型设计和超参数选择. 通过亲身实践,你将获得一手经验,这些经验将有益数据科学家的职业成长. import hashlib import ...

  6. 超详解pytorch实战Kaggle比赛:房价预测

    详解pytorch实战Kaggle比赛:房价预测 教程名称 教程地址 机器学习/深度学习 [李宏毅]机器学习/深度学习国语教程(双语字幕) 生成对抗网络 [李宏毅]生成对抗网络国语教程(双语字幕) 目 ...

  7. 实战Kaggle比赛:预测房价

    文章目录 实战Kaggle比赛:预测房价 1 - 下载和缓存数据集 2 - 访问和读取数据集 3 - 数据预处理 4 - 训练 5 - K折交叉验证 6 - 模型选择 7 - 提交你的Kaggle预测 ...

  8. 深度学习+pytorch实战Kaggle比赛(一)——房价预测

    参考书籍<动手学深度学习(pytorch版),参考网址为: https://zh-v2.d2l.ai/chapter_multilayer-perceptrons/kaggle-house-pr ...

  9. pytorch学习笔记(十四):实战Kaggle比赛——房价预测

    文章目录 1. Kaggle比赛 2. 获取和读取数据集 3. 预处理数据 4. 训练模型 5. KKK折交叉验证 6. 模型选择 7. 预测并在Kaggle提交结果 1. Kaggle比赛 Kagg ...

最新文章

  1. 【骚气的动效】外发光涟漪波纹动画、向外辐射动画效果,通常用于地图上面某一个扩散点效果
  2. Python 技术篇-通过进程名称、PID杀死windows进程的两种方法,获取当前运行程序的pid
  3. Spring Boot2.x-10 基于Spring Boot 2.1.2 + Mybatis 2.0.0实现多数据源,支持事务
  4. 奥运信息安全谁说了算?
  5. 几款优秀的点播、RTSP/RTMP直播播放器介绍
  6. UIwebView缩放
  7. 使用phpqrcode来生成二维码/thinkphp
  8. Windows 7 建立 ×××网络
  9. Java的clone方法
  10. 设计模式-(8)外观(swift版)
  11. DAY 11 | 自学前端第十一天
  12. Python AutoCAD 选择集
  13. 计算机的排版方法,计算机编辑排版系统及其方法
  14. 【LeetCode】75. Sort Colors(颜色排序)-C++实现的两种方法及超详细图解
  15. 云服务器内存占用多少,腾讯云云服务器CPU或内存占用过高怎么办?
  16. 如何禁止软件联网,防止软件自动更新
  17. 最新县及县以上行政区划代码(截止2014年10月31日)
  18. 【学术写作】优雅地翻译英文论文【保持格式】【无须排版】
  19. linux临时配置mac地址,Linux获取网卡型号、mac地址、修改IP地址的几种方法
  20. 解决google打开Github慢的问题,亲测有效

热门文章

  1. 想成为一名优秀的数据分析师,应该做些什么?
  2. Transformer升级之路:二维位置的旋转式位置编码
  3. OAG – WhoIsWho 同名消歧竞赛发布 | 10万元奖金双赛道
  4. 每周一起读 | ACL 2019 NAACL 2019:文本关系抽取专题沙龙
  5. 直播预告:基于动态词表的对话生成研究 | PaperWeekly x 微软亚洲研究院
  6. 在Android设备部署PyTorch模型
  7. vba fso读utf 文本_利用FSO对象操作文件
  8. 2021年零基础带你走进nacos的世界之云服务器下载安装nacos-小白教程,详细到爆了!
  9. springboot项目中一个实体类引用其它实体类的字段并显示到页面上
  10. AFei Loves Magic