类比于猫狗大战,利用自己搭建的CNN网络和已经搭建好的VGG16实现花朵识别。

1.导入库

注:导入的库可能有的用不到,之前打acm时留下的毛病,别管用不用得到,先写上再说

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Conv2D,Flatten,Dropout,MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os,PIL
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

2.数据下载

#数据下载
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
dataset_dir = tf.keras.utils.get_file(fname = 'flower_photos',origin=dataset_url,untar=True,cache_dir= 'E:/Deep-Learning/flowers')
dataset_dir = pathlib.Path(dataset_dir)

下载之后的文件是这个样子的。

为了方便处理,博主将数据集按照8:2的比例手动(就是不会用代码)划分成了训练集和测试集。
处理之后的文件夹如图所示:

3.计算数据总数

#将train和test下面的数据的文件路径加载到变量中
train_daisy = os.path.join(dataset_dir,"train","daisy")
train_dandelion = os.path.join(dataset_dir,"train","dandelion")
train_roses = os.path.join(dataset_dir,"train","roses")
train_sunflowers = os.path.join(dataset_dir,"train","sunflowers")
train_tulips = os.path.join(dataset_dir,"train","tulips")test_daisy = os.path.join(dataset_dir,"test","daisy")
test_dandelion = os.path.join(dataset_dir,"test","dandelion")
test_roses = os.path.join(dataset_dir,"test","roses")
test_sunflowers = os.path.join(dataset_dir,"test","sunflowers")
test_tulips = os.path.join(dataset_dir,"test","tulips")
#将训练集和测试集加载到变量中
train_dir = os.path.join(dataset_dir,"train")
test_dir = os.path.join(dataset_dir,"test")
#统计训练集和测试集的数据数目
train_daisy_num = len(os.listdir(train_daisy))
train_dandelion_num = len(os.listdir(train_dandelion))
train_roses_num = len(os.listdir(train_roses))
train_sunflowers_num = len(os.listdir(train_sunflowers))
train_tulips_num = len(os.listdir(train_tulips))
train_all = train_tulips_num+train_daisy_num+train_dandelion_num+train_roses_num+train_sunflowers_numtest_daisy_num = len(os.listdir(test_daisy))
test_dandelion_num = len(os.listdir(test_dandelion))
test_roses_num = len(os.listdir(test_roses))
test_sunflowers_num = len(os.listdir(test_sunflowers))
test_tulips_num = len(os.listdir(test_tulips))
test_all = test_tulips_num+test_daisy_num+test_dandelion_num+test_roses_num+test_sunflowers_num

4.超参数的设置

batch_size = 32
epochs = 10
height = 180
width = 180

5.数据预处理

#归一化处理
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
#规定batch_size的大小,文件路径,打乱图片顺序,规定图片的大小
train_data_gen = train_generator.flow_from_directory(batch_size=batch_size,directory=train_dir,shuffle=True,target_size=(height,width),class_mode="categorical")
test_data_gen = test_generator.flow_from_directory(batch_size=batch_size,directory=test_dir,shuffle=True,target_size=(height,width),class_mode="categorical")

6.模型搭建&&模型训练

模型采用的是三层卷积池化层+Dropout+Flatten+两层全连接层

model = tf.keras.Sequential([tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Dropout(0.5),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128,activation="relu"),tf.keras.layers.Dense(5,activation='softmax')
])
model.compile(optimizer="adam",loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=["acc"])
history = model.fit_generator(train_data_gen,steps_per_epoch=train_all//batch_size,epochs=epochs,validation_data=test_data_gen,validation_steps=test_all//batch_size)

实验效果如下所示:

出现了过拟合的情况,epochs为50的情况下,运行结果如下所示:

测试集的准确率相比于epochs为10的情况高了10%左右,但是准确率仍然很低,而且过拟合的情况还是很严重,运行时间特别长~

7.数据增强

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255,rotation_range=45,#倾斜45°width_shift_range=.15,height_shift_range=.15,horizontal_flip=True,#水平翻转zoom_range=0.5)#随机放大

实验结果如下所示:
epochs为20,过拟合情况得到了改善。

loss: 1.2052 - acc: 0.6962 - val_loss: 1.2502 - val_acc: 0.6534
#epochs为20的时候准确率是65,而epochs为50的时候准确率为66,并没有很大的改善。

8.迁移学习

利用别人训练好的VGG16网络对同样的数据进行训练。

#引用VGG16模型
conv_base = tf.keras.applications.VGG16(weights='imagenet',include_top = False)
#设置为不可训练
conv_base.trainable = False
#搭建模型
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(5,activation='sigmoid'))

训练结果如下图所示:

在epochs只有10的情况下,准确率就达到了80,有了明显的提升。
努力加油a啊

深度学习之基于卷积神经网络实现花朵识别相关推荐

  1. Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类)

    Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类) 1.卷积神经网络 1.1卷积神经网络简介 1.2卷积运算 1.3 深度学习与小数据问题的相关性 2.下载数据 2.1下载原始数据 ...

  2. 深度学习之基于卷积神经网络(VGG16)实现性别判别

    无意间在kaggle上发现的一个数据集,旨在提高网络模型判别男女的准确率,博主利用迁移学习试验了多个卷积神经网络,最终的模型准确率在95%左右.并划分了训练集.测试集.验证集三类,最终在验证集上的准确 ...

  3. 深度学习之基于卷积神经网络(VGG16CNN)实现海贼王人物识别

    硬件问题真的是搞机器学习的一个痛处,更何况这只是入门级别的. 基于CNN和VGG16,实现对海贼王人物的分类识别.本次自己动手搭建了VGG16 网络,并且和迁移学习的VGG16的网络的实验效果做了一个 ...

  4. 深度学习之基于卷积神经网络实现服装图像识别

    本博客与手写数字识别大同小异. 1.导入所需库 import tensorflow as tf from tensorflow.keras import datasets, layers, model ...

  5. 深度学习之基于卷积神经网络实现超大Mnist数据集识别

    在以往的手写数字识别中,数据集一共是70000张图片,模型准确率可以达到99%以上的准确率.而本次实验的手写数字数据集中有120000张图片,而且数据集的预处理方式也是之前没有遇到过的.最终在验证集上 ...

  6. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-03-基于Python的LeNet之LR

    原地址可以查看更多信息 本文主要参考于:Classifying MNIST digits using Logistic Regression  python源代码(GitHub下载 CSDN免费下载) ...

  7. 【深度学习系列】卷积神经网络CNN原理详解(一)——基本原理(1)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  8. 深度学习21天——卷积神经网络(CNN):天气识别(第5天)

    目录 一.前期准备 1.1 设置GPU 1.2 导入数据 1.2.1 np.random.seed( i ) 1.2.2 tf.random.set_seed() 1.3 查看数据 二.数据预处理 2 ...

  9. 深度学习笔记:卷积神经网络的可视化--卷积核本征模式

    目录 1. 前言 2. 代码实验 2.1 加载模型 2.2 构造返回中间层激活输出的模型 2.3 目标函数 2.4 通过随机梯度上升最大化损失 2.5 生成滤波器模式可视化图像 2.6 将多维数组变换 ...

最新文章

  1. 外部样式表声明的样式并不会进入style对象
  2. mfc只有doc才能序列化吗_MFC序列化-IMPLEMENT_SERIAL(...)
  3. 一文详尽支付宝系统架构(附内部架构图)
  4. Ubuntu20.04 服务器版安装
  5. Ubuntu 环境初始化
  6. 用大白话彻底搞懂 HBase RowKey 详细设计!
  7. pve万兆网卡驱动_无线环境下打游戏,还能不能更稳?附各类AX网卡换装思路
  8. 软件安全测试之应用安全测试
  9. Spring注解原理详解
  10. linux环境安装的odac,net不安装Oracle11g客户端直接使用ODAC
  11. 淘宝店铺的装修是店铺的门面,如何进行淘宝店铺装修?需要注意的点有哪些?
  12. 实现isPrime()函数,参数是整数,如果整数是质数, 返回True,否则返回False
  13. python实现简单的神经网络,python的神经网络编程
  14. adb shell打开开发者选项
  15. 单目标跟踪SiamMask:特定目标车辆追踪 part2
  16. 课时31:永久储存:腌制一缸美味的泡菜
  17. 《卸甲笔记》-基础语法对比
  18. 中心矩和原点矩_原点矩与中心矩.ppt
  19. 小黑重装WIFI之解 - 硬件无线电已关闭 802.11无线通信 禁用状态无法启用
  20. opencv 表识别 工业表智能识别 数字式表盘识别,指针式表盘刻度识别,分为表检测,表盘纠正,刻度分割,刻度拉直识别

热门文章

  1. C# XML 添加,修改,删除Xml节点
  2. 微信小程序:背景图片在电脑可以显示,真机测试时无法显示
  3. IOS开发基础之微博项目第1天-OC版
  4. IOS开发基础之汽车品牌项目-14
  5. 怎么把python程序发给别人_想把你写的Python程序发给别人用?打包成exe啊!
  6. java exec mvn_maven---常用插件之EXEC
  7. Visual Studio 智能提示消失解决办法
  8. 猪行天下之Python基础——1.1 Python开发环境搭建
  9. 手把手教你撸一个简易的 webpack
  10. 〖前端开发〗HTML/CSS基础知识学习笔记