1 赛题简介

对近5万张按“物种-病害-程度”分成61类的植物叶片照片进行分类

比赛地址:AI challenger比赛—农作物病害检测

2 框架

我使用的是Keras,以TensorFlow为后端,手动实现了DenseNet用于图片分类
由于Kaggle现在可以免费使用GPU,所以采用将数据上传至Kaggle的私人Dataset上,在其上创建Kernel进行模型训练
(上传需要翻墙,有梯子最好)

3 DenseNet模型实现

def dense_block(x, blocks, name):for i in range(blocks):x = conv_block(x, 32, name=name + '_block' + str(i + 1))return x
def transition_block(x, reduction, name):bn_axis = 3x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_bn')(x)x = layers.Activation('relu', name=name + '_relu')(x)x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,use_bias=False,name=name + '_conv')(x)x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)return x
def conv_block(x, growth_rate, name):bn_axis = 3x1 = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name=name + '_0_bn')(x)x1 = layers.Activation('relu', name=name + '_0_relu')(x1)x1 = layers.Conv2D(4 * growth_rate, 1,use_bias=False,name=name + '_1_conv')(x1)x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_1_bn')(x1)x1 = layers.Activation('relu', name=name + '_1_relu')(x1)x1 = layers.Conv2D(growth_rate, 3,padding='same',use_bias=False,name=name + '_2_conv')(x1)x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])return x
def DenseNet(blocks, input_shape=(150,150,3), classes=61):img_input = Input(shape=input_shape)bn_axis = 3x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)x = layers.Activation('relu', name='conv1/relu')(x)x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)x = dense_block(x, blocks[0], name='conv2')x = transition_block(x, 0.5, name='pool2')x = dense_block(x, blocks[1], name='conv3')x = transition_block(x, 0.5, name='pool3')x = dense_block(x, blocks[2], name='conv4')x = transition_block(x, 0.5, name='pool4')x = dense_block(x, blocks[3], name='conv5')x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)x = layers.GlobalAveragePooling2D(name='avg_pool')(x)x = Dense(512)(x)x = BatchNormalization()(x)x = PReLU()(x)x = Dropout(0.5)(x)x = Dense(classes, activation='softmax', name='fc61')(x)inputs = img_inputmodel = Model(inputs, x, name='densenet')return model

调用DenseNet函数即可创建

model = DenseNet(blocks=[6, 12, 48, 32], input_shape=(150,150,3),classes=61)
model.summary()

4 数据准备

1、训练集、验证集生产器
这里对图片进行图像预处理,增加图片归一化、适度旋转、随机缩放、上下翻转

train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,rotation_range=20,zoom_range=0.2,horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1. / 255)

2、读取数据
从目录中读取数据

img_width, img_height = 150, 150
train_data_dir = '../input/train/train'
validation_data_dir = '../input/val/val'
batch_size = 64
classes = 61train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical') #多分类validation_generator = val_datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical') #多分类

5 模型训练

1、先对模型进行预编译

model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001),metrics=['accuracy'])

2、训练模型
增加自动更新学习率和保存在验证集最后的模型参数

learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', patience=3, verbose=1,factor=0.5, min_lr=0.000001)
checkpoint = ModelCheckpoint(model_name, monitor='val_acc', save_best_only=True)
history = model.fit_generator(train_generator,steps_per_epoch=nb_train_samples // batch_size,epochs=30,validation_data=validation_generator,validation_steps=nb_validation_samples // batch_size,callbacks=[checkpoint, learning_rate_reduction])

训练次数由于受Kaggle中Kernel的使用时间受限,只能训练6小时,所以只能暂时训练30,不过可以多次迭代训练。

6 模型预测

由于文件夹存放顺序跟window上不一样,所以实际上文件夹在Kaggle上Dataset上的存放顺序如下

rr = [0,1,10,11,12,13,14,15,16,17,18,19,2,20,21,22,23,24,25,26,27,28,29,3,30,31,32,33,34,35,36,37,38,39,4,40,41,42,43,44,45,46,47,48,49,5,50,51,52,53,54,55,56,57,58,59,6,60,7,8,9]images = os.listdir('../input/ai-challenger-pdr2018/testa/testA')result = []
for img1 in images:image_path = '../input/ai-challenger-pdr2018/testa/testA/' + img1img = image.load_img(image_path, target_size=(150, 150))x = image.img_to_array(img)/255.0x = np.expand_dims(x, axis=0)preds = model.predict(x)tmp = dict()tmp['image_id'] = img1tmp['disease_class']=rr[int(np.argmax(preds))]result.append(tmp)

最后保存为json

import json
json2 = json.dumps(result)
f = open('result.json','w',encoding='utf-8')
f.write(json2)
f.close()

7 提交结果

最终的结果是0.87395的成绩

8 完整代码参考

DenseNet模型训练 plants_disease_detection

如果你觉得我写的不错,请给我一下Star(^_^),谢谢!

AI challenger 2018图片分类比赛—农作物病害检测相关推荐

  1. 世界首个!AI农作物病害检测竞赛火热进行中 | AI Challenger 全球AI挑战赛

    乾明 发自 凹非寺 量子位 出品 | 公众号 QbitAI 如果你用谷歌搜索"AI+农业"或者"人工智能+农业",就会发现与AI在其他领域的应用相比,农业依旧是 ...

  2. 总奖金300万的AI Challenger 2018进入第二阶段,决赛在即!

    参加 2018 AI开发者大会,请点击 ↑↑↑ 此前,AI科技大本营曾报道过奖金池高达 300 万元的 AI Challenger 2018 比赛.与往届不同,今年的比赛共有 5 个主赛道,5 个实验 ...

  3. AI Challenger 2018决赛在即,12月18-19日极客峰会免费抢票!

    第二届"AI Challenger 全球AI挑战赛"各赛道竞赛经过两个多月的激烈角逐,报名将于北京时间2018年11月11日23:59:59正式截止,随即进入决赛阶段,最终每个竞赛 ...

  4. AI Challenger 2018:细粒度用户评论情感分析冠军思路总结

    2018年8月-12月,由美团点评.创新工场.搜狗.美图联合主办的"AI Challenger 2018全球AI挑战赛"历经三个多月的激烈角逐,冠军团队从来自全球81个国家.100 ...

  5. AI Challenger 2018 机器翻译参赛总结

    金山集团 AI Lab 组队参加了 AI Challenger 2018 全球挑战赛的英中机器翻译项目,并且获得冠军.  AI Challenger 2018 主题为"用 AI 挑战真实世界 ...

  6. 基于深度学习的农作物病害检测

    基于深度学习的农作物病害检测 1.研究思路 47 637 张图片总共 61 个分类标签.6 种模型对图像进行特征抽取. 采用交叉熵和正则化项组成损失函数进行反向传播调整,对数据集进行 4 种不同情况的 ...

  7. 深度学习(二十)基于Overfeat的图片分类、定位、检测

    基于Overfeat的图片分类.定位.检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187881 作者:hjimce 一.相关理论 本篇博 ...

  8. AI Challenger 全球AI挑战赛[二]——场景分类比赛介绍(附数据集和基线模型百度云下载)

    AI Challenger 全球AI挑战赛       场景分类 [ 2017 ] 传送门 目的:寻找一个更鲁棒的场景分类模型,解决图片的角度.尺度.和光照的多样性问题 一.比赛介绍 赛题简介 移动互 ...

  9. 计算机视觉农作物检测,基于计算机视觉的农作物病害检测系统的研究

    摘要: 农作物病害是制约农业发展的主要因素之一,准确,高效地识别病害对于保证农作物的正常生长具有重要的意义.计算机视觉技术对加速农业现代化建设,提高生产效率影响深远. 本文以农作物病害类别的检测与识别 ...

最新文章

  1. 全面支持三大主流环境 |百度PaddlePaddle新增Windows环境支持
  2. C语言求幺元的函数,离散数学实验指导书及其答案.doc
  3. 《守望先锋》阵亡镜头、全场最佳和亮眼表现是如何设计
  4. 用twisted为未来安排任务(Scheduling tasks for the future
  5. java int64如何定义_java – 具有两个int属性的自定义类的hashCode是什么?
  6. Android当中layer-list使用来实现多个图层堆叠到一块儿
  7. storm spout mysql_storm+mysql集成
  8. ptyhon中文本挖掘精简版
  9. C#中用WMI实现对驱动的查询
  10. python基础笔记_python基础学习笔记
  11. overflow与BFC解说
  12. SHELL中如何对一个变量进行算术操作(加减)
  13. SocksCap64全局代理设置
  14. 使用mutt和msmtp发送邮件
  15. java 调用bat脚本 等待返回_java程序调用bat脚本
  16. 录播网站 服务器,录播服务器
  17. npm报错, Error: EPERM: operation not permitted, mkdir
  18. 业余无线电新手入门基础知识(全网最全)
  19. 如何判断一个数是不是整数
  20. 安装compiz-fusion

热门文章

  1. 机房收费管理系统之退卡
  2. 为什么自动驾驶遇瓶颈,但自动代客泊车却很热?
  3. 浏览器如何截图整个滚动屏 ?
  4. SQL Server 数据库之身份验证和访问控制
  5. snap 无法卸载_你手机里有哪些不想卸载的良心 App?
  6. systemd 介绍
  7. ERROR 1366 (HY000): Incorrect string value: ‘\xE8\xB5\xB5\xE9\x9B\xB7‘ for column ‘s_name‘ at row 1
  8. 计算机linux二级试题,计算机二级考试题及答案
  9. 社会实践分组(c++)
  10. Roofline-on-NVIDIA-GPUs代码分析