tensorflow2.4使用GooleNet实现识别植物花朵图像项目

1. 数据集下载

链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg
提取码:bhjx

2. GooleNet网络介绍

GoogleNet是google推出的基于Inception模块的深度神经网络模型,在2014年的ImageNet竞赛中夺得了冠军。
GoogleNet在当时的创新点有两个:

  • 使用了模型融合
    在GoogleNet中,运用了许多的Inception模块。

    上图中,左边是原始的Inception结构,右边是优化后的Inception结构。
    Inception结构特点:使用不同卷积核,提取不同特征,最后融合起来。
  • 使用1×1卷积
    • 1×1卷积作用:

      1. 增加网络非线性——网络层数更多。
      2. 减少计算量和需要训练的权值。

网络结构:

3. 代码演示

3.1 导入依赖库

from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler, TensorBoard
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
import tensorflow as tf
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Conv2D, MaxPool2D, Flatten, Input, concatenate, AvgPool2D
from tensorflow.keras.models import Model

3.2 定义超参数

classes = 17 # 类别
batch_size = 32 # 批次大小
epochs = 100 # 训练的轮次
img_size = 224 # resize图片的大小
lr = 1e-3  # 学习率
datasets = './dataset/data_flower' # 数据集的路径

3.3 定义数据集及数据增强

def data_process_func():# ---------------------------------- ##   训练集进行的数据增强操作#   1. rotation_range -> 随机旋转角度#   2. width_shift_range -> 随机水平平移#   3. width_shift_range -> 随机数值平移#   4. rescale -> 数据归一化#   5. shear_range -> 随机错切变换#   6. zoom_range -> 随机放大#   7. horizontal_flip -> 水平翻转#   8. brightness_range -> 亮度变化#   9. fill_mode -> 填充方式# ---------------------------------- #train_data = ImageDataGenerator(rotation_range=50, width_shift_range=0.1, height_shift_range=0.1,rescale=1/255.0,shear_range=10,zoom_range=0.1,horizontal_flip=True,brightness_range=(0.7, 1.3),fill_mode='nearest')# ---------------------------------- ##   测试集数据增加操作#   归一化即可# ---------------------------------- #test_data = ImageDataGenerator(rescale=1/255)# ---------------------------------- ##   训练器生成器#   测试集生成器# ---------------------------------- #train_generator = train_data.flow_from_directory(f'{datasets}/train',target_size=(img_size, img_size),batch_size=batch_size)test_generator = test_data.flow_from_directory(f'{datasets}/test',target_size=(img_size, img_size),batch_size=batch_size)return train_generator, test_generator

3.4 定义网络结构

def Inception(x, filters, name):t1 = Conv2D(filters=filters[0], kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_1_Conv1')(x)t2 = Conv2D(filters=filters[1], kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_2_Conv1')(x)t2 = Conv2D(filters=filters[2], kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_2_Conv2')(t2)t3 = Conv2D(filters=filters[3], kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_3_Conv1')(x)t3 = Conv2D(filters=filters[4], kernel_size=(5, 5), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_3_Conv2')(t3)pooling = MaxPool2D(pool_size=(3, 3), strides=(1, 1), padding='same', name=f'{name}Inception_4_Pool1')(x)pooling = Conv2D(filters=filters[5],kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu', name=f'{name}Inception_4_Pool2')(pooling)x = concatenate([t1, t2, t3, pooling], axis=3)return xdef Goolenet(inputs, classes):x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding='same', activation='relu', name='Conv1')(inputs)x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='Pool1')(x)x = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1), padding='same', activation='relu', name='Conv2')(x)x = Conv2D(filters=192, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', name='Conv3')(x)x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='Pool2')(x)x = Inception(x, filters=[64, 96, 128, 16, 32, 32], name='Inception_block1')x = Inception(x, filters=[128, 128, 192, 32, 96, 64], name='Inception_block2')x = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same', name='Pool3')(x)x = Inception(x, filters=[192,96,208,16,48,64], name='Inception_block3')x = Inception(x, filters=[160,112,224,24,64,64], name='Inception_block4')x = Inception(x, filters=[128,128,256,24,64,64], name='Inception_block5')x = Inception(x, filters=[112,144,288,32,64,64], name='Inception_block6')x = Inception(x, filters=[256,160,320,32,128,128], name='Inception_block7')x = MaxPool2D(pool_size=3,strides=2,padding='same', name='Pool4')(x)x = Inception(x, [256,160,320,32,128,128], name='Inception_block8')x = Inception(x, [384,192,384,48,128,128], name='Inception_block9')x = AvgPool2D(pool_size=7,strides=7,padding='same', name='AvgPool1')(x)x = Flatten()(x)x = Dropout(0.4)(x)x = Dense(classes, activation='softmax')(x)return x

3.5 定义学习率调整函数

# 学习率调整
def adjust_lr(epoch, lr=lr):print("Seting to %s" % (lr))if epoch < 3:return lrelse:return lr * 0.93

3.6 开始训练模型

# 使用GPU
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)
inputs = Input(shape=(img_size,img_size,3))
# 构造器
train_generator, test_generator = data_process_func()
model = Model(inputs=inputs, outputs=Goolenet(inputs=inputs, classes=classes))
callbackss = [EarlyStopping(monitor='val_loss', patience=10, verbose=1), # val_loss若10个轮次还不下降,就停止训练ModelCheckpoint('logs/ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss',save_weights_only=True, save_best_only=False, period=1), # 每个轮次都保存权重文件,在logs文件夹下LearningRateScheduler(adjust_lr)# TensorBoard(log_dir='./logs1')]
print('---------->epoch0 starting--------->')
model.compile(optimizer=Adam(lr=lr), loss='categorical_crossentropy', metrics=['accuracy'])
history = model.fit(x                    = train_generator,validation_data      = test_generator,epochs               = epochs,workers              = 1,callbacks            = callbackss,
)

3.7 预测图片

import tensorflow as tf
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import numpy as np
import os
import matplotlib.pyplot as plt# -------------------------------------------------#
#   指定所测试集的路径
#   权重路径
#   网络
#   类别
#   图片大小
# -------------------------------------------------#
if True:datasets = './dataset/data_flower/test'names = os.listdir(datasets)weight = './model_data/test_acc-0.754-Goolenet_flowers_val_loss0.944-1.h5'classes = 17img_size = 224# -------------------------------------------------## 归一化
def preprocess_input(x):x /= 255return xinputs = Input(shape=(img_size,img_size,3))
model = Model(inputs=inputs, outputs=Goolenet(inputs=inputs, classes=classes))# -------------------------------------------------#
#   载入模型
# -------------------------------------------------#
model.load_weights(weight)
while True:img_path = input('input img_path:')try:img = Image.open(img_path)img = img.resize((224, 224))image_data = np.expand_dims(preprocess_input(np.array(img, np.float32)), 0)except:print('The path is error!')continueelse:plt.imshow(img)plt.axis('off')p =model.predict(image_data)[0]print(model.predict(image_data))pred_name = names[np.argmax(p)]plt.title('%s:%.3f'%(pred_name, np.max(p)))plt.show()

输入:./dataset/data_flower/test/flower1/image_0082.jpg

结果如下:

tensorflow2.4使用GooleNet实现识别植物花朵图像项目相关推荐

  1. 利用CNN和迁移学习方法识别植物叶片疾病

    利用CNN和迁移学习方法识别植物叶片疾病 Abstract 及时发现和早期预防作物病害对提高产量至关重要.由于深度卷积神经网络(CNN)在机器视觉领域取得了令人瞩目的成果,本文采用深度卷积神经网络(C ...

  2. 好用的识别植物的软件app合集分享,快码住了

    在我们的日常生活中,我们经常会看到各种各样的植物.但是,我们并不总是知道这些植物的名称和特征.好在现在有很多植物识别软件,可以轻松地识别植物并获取相关信息.在本文中,我们将介绍四个植物识别软件,让你知 ...

  3. 两个妙招教你怎么拍照识别植物,增长见识

    每个地方都有着独特的气候,而在这些气候的影响下,山川湖海都有着它们特有的形态,自然,植被的生长也会受气候.环境的影响.当我们去到其他城市游玩,总会被一些新奇的风景.动植物所吸引,从而想要了解这些事物, ...

  4. 这三款软件让你轻松实现在线扫花识别植物

    如今,鲜花是我们日常生活中最常见的植物,但是随着鲜花种类的不断增多,它的许多的种类信息,想必大多数的朋友都难以认识清楚,因此,有的人就会使用一些识别鲜花的APP来帮助我们通过拍照而轻松获知鲜花的信息, ...

  5. 有什么拍照识别植物的软件?建议收藏这几个软件

    在我们日常生活中,经常会接触到各种植物,比如去公园游玩的时候,或者是去花鸟市场的时候.有时看到一些美丽的植物,但是我们却不认识植物叫什么,这种感觉就非常难受.所以为了解决这个问题,今天我就给大家带来教 ...

  6. 如何通过拍照识别植物?试试这几个软件

    小孩子有着旺盛的好奇心,我有时候带他们出去玩,他们看到什么就问什么,问得比较多的就是路边的植物,我常常被问得哑口无言,毕竟我是真的不知道它们是什么植物.这个问题要解决很容易,只要我们拥有一个植物拍照识 ...

  7. 有了这几款软件,就不用纠结拍照识别植物的软件哪个好了

    相信很多小伙伴都曾外出旅行,去陌生的城市感受其他地方的风土人情,当然,旅途中我们也会看到许多未曾见到过的风景,感受不同的文化习俗,甚至会见到许多不同气候影响下特有的植被等. 植物是我们生活中再常见不过 ...

  8. 拍照识别植物app哪个好?来看看这几个工具

    不知道小伙伴们出门看到好看的植物会不会想要在家里自己种植呢?可是我们该怎么知道它是什么品种呢?其实我们可以借助一些植物识别软件来帮助我们识别,但识别软件那么多,哪个好呢?大家不用为这个问题发愁,下面我 ...

  9. 我们如何一键识别?拍照识别植物的软件有哪些?

    在生活中我们常常会遇到这样的情况,比如在路边看见一株植物或者花卉,觉得自己非常喜欢,但是观看许久却不知道它的名字叫什么?其实这时候我们可以借助一些识别软件来识别植物,那你们知道识别植物的软件叫什么吗? ...

最新文章

  1. 田渊栋的2021年终总结:多读历史!历史就是一个大规模强化学习训练集
  2. shell中read用法
  3. android 实现表格横向混动_Flutter混合开发和Android动态更新实践
  4. etc/ld.so.conf的使用说明
  5. (转)函数指针,指针函数,指向函数的指针,返回指针的函数
  6. Linux EXT3文件系统下成功恢复误删的文件
  7. 底层原理_Spring框架底层原理IoC
  8. Rstudio修改背景颜色和源
  9. 动态规划求解限时采药问题(洛谷P1048题题解,Java语言描述)
  10. 微mysql命令行_mysql命令大全
  11. oracle rac 启动失败has,oracle11.2.0.4 rac asm启动故障
  12. python中lambda的另类使用
  13. (转) hash 函数及其重要性
  14. Linux内核分析:跟踪分析Linux内核的启动过程
  15. JS入门到精通完整版
  16. 利用selenium 实现对百度图片搜索中的图片的抓取
  17. linux-2.6.32在mini2440开发板上移植(15)之移植看门狗驱动
  18. 【iOS开发】——weak底层原理
  19. Hibernate手动控制事物
  20. 苹果手机怎么在照片上添加文字_手机照片如何添加文字?原来方法这么简单,手把手教你学会。...

热门文章

  1. 前端(HTML5基础学习笔记)
  2. 调焦后焦实现不同距离成像_分辨率、调焦和景深
  3. wp:涅普冬令营(2021) 监听消息
  4. 【思维导图怎么画】万彩脑图大师教程 | 嵌入企业Logo到思维导图
  5. android动态壁纸2.2.1,动态壁纸选择器
  6. 数据结构翻转课堂答疑实录——概述
  7. 工程实践 | 在 Flutter 中实现一个精准的滑动埋点
  8. NAACL 2019 | ​注意力模仿:通过关注上下文来更好地嵌入单词
  9. matlab中设x=zsin3x,三阶偏导数设e的sin(2x+3y)次方,求Z的三阶偏导数是多少?
  10. solidworks2014方程式添加全局变量存在句法错误的解决方案