AI challenger 2018图片分类比赛—农作物病害检测
1 赛题简介
对近5万张按“物种-病害-程度”分成61类的植物叶片照片进行分类
比赛地址:AI challenger比赛—农作物病害检测
2 框架
我使用的是Keras,以TensorFlow为后端,手动实现了DenseNet用于图片分类
由于Kaggle现在可以免费使用GPU,所以采用将数据上传至Kaggle的私人Dataset上,在其上创建Kernel进行模型训练
(上传需要翻墙,有梯子最好)
3 DenseNet模型实现
def dense_block(x, blocks, name):for i in range(blocks):x = conv_block(x, 32, name=name + '_block' + str(i + 1))return x
def transition_block(x, reduction, name):bn_axis = 3x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_bn')(x)x = layers.Activation('relu', name=name + '_relu')(x)x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,use_bias=False,name=name + '_conv')(x)x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)return x
def conv_block(x, growth_rate, name):bn_axis = 3x1 = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name=name + '_0_bn')(x)x1 = layers.Activation('relu', name=name + '_0_relu')(x1)x1 = layers.Conv2D(4 * growth_rate, 1,use_bias=False,name=name + '_1_conv')(x1)x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_1_bn')(x1)x1 = layers.Activation('relu', name=name + '_1_relu')(x1)x1 = layers.Conv2D(growth_rate, 3,padding='same',use_bias=False,name=name + '_2_conv')(x1)x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])return x
def DenseNet(blocks, input_shape=(150,150,3), classes=61):img_input = Input(shape=input_shape)bn_axis = 3x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)x = layers.Activation('relu', name='conv1/relu')(x)x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)x = dense_block(x, blocks[0], name='conv2')x = transition_block(x, 0.5, name='pool2')x = dense_block(x, blocks[1], name='conv3')x = transition_block(x, 0.5, name='pool3')x = dense_block(x, blocks[2], name='conv4')x = transition_block(x, 0.5, name='pool4')x = dense_block(x, blocks[3], name='conv5')x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)x = layers.GlobalAveragePooling2D(name='avg_pool')(x)x = Dense(512)(x)x = BatchNormalization()(x)x = PReLU()(x)x = Dropout(0.5)(x)x = Dense(classes, activation='softmax', name='fc61')(x)inputs = img_inputmodel = Model(inputs, x, name='densenet')return model
调用DenseNet函数即可创建
model = DenseNet(blocks=[6, 12, 48, 32], input_shape=(150,150,3),classes=61)
model.summary()
4 数据准备
1、训练集、验证集生产器
这里对图片进行图像预处理,增加图片归一化、适度旋转、随机缩放、上下翻转
train_datagen = ImageDataGenerator(rescale=1. / 255,shear_range=0.2,rotation_range=20,zoom_range=0.2,horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1. / 255)
2、读取数据
从目录中读取数据
img_width, img_height = 150, 150
train_data_dir = '../input/train/train'
validation_data_dir = '../input/val/val'
batch_size = 64
classes = 61train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical') #多分类validation_generator = val_datagen.flow_from_directory(validation_data_dir,target_size=(img_width, img_height),batch_size=batch_size,class_mode='categorical') #多分类
5 模型训练
1、先对模型进行预编译
model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001),metrics=['accuracy'])
2、训练模型
增加自动更新学习率和保存在验证集最后的模型参数
learning_rate_reduction = ReduceLROnPlateau(monitor='val_acc', patience=3, verbose=1,factor=0.5, min_lr=0.000001)
checkpoint = ModelCheckpoint(model_name, monitor='val_acc', save_best_only=True)
history = model.fit_generator(train_generator,steps_per_epoch=nb_train_samples // batch_size,epochs=30,validation_data=validation_generator,validation_steps=nb_validation_samples // batch_size,callbacks=[checkpoint, learning_rate_reduction])
训练次数由于受Kaggle中Kernel的使用时间受限,只能训练6小时,所以只能暂时训练30,不过可以多次迭代训练。
6 模型预测
由于文件夹存放顺序跟window上不一样,所以实际上文件夹在Kaggle上Dataset上的存放顺序如下
rr = [0,1,10,11,12,13,14,15,16,17,18,19,2,20,21,22,23,24,25,26,27,28,29,3,30,31,32,33,34,35,36,37,38,39,4,40,41,42,43,44,45,46,47,48,49,5,50,51,52,53,54,55,56,57,58,59,6,60,7,8,9]images = os.listdir('../input/ai-challenger-pdr2018/testa/testA')result = []
for img1 in images:image_path = '../input/ai-challenger-pdr2018/testa/testA/' + img1img = image.load_img(image_path, target_size=(150, 150))x = image.img_to_array(img)/255.0x = np.expand_dims(x, axis=0)preds = model.predict(x)tmp = dict()tmp['image_id'] = img1tmp['disease_class']=rr[int(np.argmax(preds))]result.append(tmp)
最后保存为json
import json
json2 = json.dumps(result)
f = open('result.json','w',encoding='utf-8')
f.write(json2)
f.close()
7 提交结果
最终的结果是0.87395的成绩
8 完整代码参考
DenseNet模型训练 plants_disease_detection
如果你觉得我写的不错,请给我一下Star(^_^
),谢谢!
AI challenger 2018图片分类比赛—农作物病害检测相关推荐
- 世界首个!AI农作物病害检测竞赛火热进行中 | AI Challenger 全球AI挑战赛
乾明 发自 凹非寺 量子位 出品 | 公众号 QbitAI 如果你用谷歌搜索"AI+农业"或者"人工智能+农业",就会发现与AI在其他领域的应用相比,农业依旧是 ...
- 总奖金300万的AI Challenger 2018进入第二阶段,决赛在即!
参加 2018 AI开发者大会,请点击 ↑↑↑ 此前,AI科技大本营曾报道过奖金池高达 300 万元的 AI Challenger 2018 比赛.与往届不同,今年的比赛共有 5 个主赛道,5 个实验 ...
- AI Challenger 2018决赛在即,12月18-19日极客峰会免费抢票!
第二届"AI Challenger 全球AI挑战赛"各赛道竞赛经过两个多月的激烈角逐,报名将于北京时间2018年11月11日23:59:59正式截止,随即进入决赛阶段,最终每个竞赛 ...
- AI Challenger 2018:细粒度用户评论情感分析冠军思路总结
2018年8月-12月,由美团点评.创新工场.搜狗.美图联合主办的"AI Challenger 2018全球AI挑战赛"历经三个多月的激烈角逐,冠军团队从来自全球81个国家.100 ...
- AI Challenger 2018 机器翻译参赛总结
金山集团 AI Lab 组队参加了 AI Challenger 2018 全球挑战赛的英中机器翻译项目,并且获得冠军. AI Challenger 2018 主题为"用 AI 挑战真实世界 ...
- 基于深度学习的农作物病害检测
基于深度学习的农作物病害检测 1.研究思路 47 637 张图片总共 61 个分类标签.6 种模型对图像进行特征抽取. 采用交叉熵和正则化项组成损失函数进行反向传播调整,对数据集进行 4 种不同情况的 ...
- 深度学习(二十)基于Overfeat的图片分类、定位、检测
基于Overfeat的图片分类.定位.检测 原文地址:http://blog.csdn.net/hjimce/article/details/50187881 作者:hjimce 一.相关理论 本篇博 ...
- AI Challenger 全球AI挑战赛[二]——场景分类比赛介绍(附数据集和基线模型百度云下载)
AI Challenger 全球AI挑战赛 场景分类 [ 2017 ] 传送门 目的:寻找一个更鲁棒的场景分类模型,解决图片的角度.尺度.和光照的多样性问题 一.比赛介绍 赛题简介 移动互 ...
- 计算机视觉农作物检测,基于计算机视觉的农作物病害检测系统的研究
摘要: 农作物病害是制约农业发展的主要因素之一,准确,高效地识别病害对于保证农作物的正常生长具有重要的意义.计算机视觉技术对加速农业现代化建设,提高生产效率影响深远. 本文以农作物病害类别的检测与识别 ...
最新文章
- 全面支持三大主流环境 |百度PaddlePaddle新增Windows环境支持
- C语言求幺元的函数,离散数学实验指导书及其答案.doc
- 《守望先锋》阵亡镜头、全场最佳和亮眼表现是如何设计
- 用twisted为未来安排任务(Scheduling tasks for the future
- java int64如何定义_java – 具有两个int属性的自定义类的hashCode是什么?
- Android当中layer-list使用来实现多个图层堆叠到一块儿
- storm spout mysql_storm+mysql集成
- ptyhon中文本挖掘精简版
- C#中用WMI实现对驱动的查询
- python基础笔记_python基础学习笔记
- overflow与BFC解说
- SHELL中如何对一个变量进行算术操作(加减)
- SocksCap64全局代理设置
- 使用mutt和msmtp发送邮件
- java 调用bat脚本 等待返回_java程序调用bat脚本
- 录播网站 服务器,录播服务器
- npm报错, Error: EPERM: operation not permitted, mkdir
- 业余无线电新手入门基础知识(全网最全)
- 如何判断一个数是不是整数
- 安装compiz-fusion
热门文章
- 机房收费管理系统之退卡
- 为什么自动驾驶遇瓶颈,但自动代客泊车却很热?
- 浏览器如何截图整个滚动屏 ?
- SQL Server 数据库之身份验证和访问控制
- snap 无法卸载_你手机里有哪些不想卸载的良心 App?
- systemd 介绍
- ERROR 1366 (HY000): Incorrect string value: ‘\xE8\xB5\xB5\xE9\x9B\xB7‘ for column ‘s_name‘ at row 1
- 计算机linux二级试题,计算机二级考试题及答案
- 社会实践分组(c++)
- Roofline-on-NVIDIA-GPUs代码分析