猫狗大战是kaggle平台上的一个比赛,用于实现猫和狗的二分类问题。最近在学卷积神经网络,所以自己动手搭建了几层网络进行训练,然后利用迁移学习把别人训练好的模型直接应用于猫狗分类这个数据集,比较一下实验效果。

自己搭建网络

需要用到的库

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import numpy as np
import matplotlib.pyplot as plt

数据集加载
数据是通过这个网站下载的,也可以自己先下载好。

dataset_url = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"dataset_path = tf.keras.utils.get_file("cats_and_dogs_filtered.zip", origin=dataset_url, extract=True)
dataset_dir = os.path.join(os.path.dirname(dataset_path), "cats_and_dogs_filtered")train_cats = os.path.join(dataset_dir,"train","cats")
train_dogs = os.path.join(dataset_dir,"train","dogs")test_cats = os.path.join(dataset_dir,"validation","cats")
test_dogs = os.path.join(dataset_dir,"validation","dogs")train_dir = os.path.join(dataset_dir,"train")
test_dir = os.path.join(dataset_dir,"validation")

统计训练集和测试集的大小

train_dogs_num = len(os.listdir(train_dogs))
train_cats_num = len(os.listdir(train_cats))test_dogs_num = len(os.listdir(test_dogs))
test_cats_num = len(os.listdir(test_cats))train_all = train_cats_num+train_dogs_num
test_all = test_cats_num+test_dogs_num

设置超参数

batch_size = 128
epochs = 50
height = 150
width = 150

数据预处理
我们所作的预处理包含以下几步:
①读取图像数据。
②对图像内容进行解码并转换成合适的格式。
③对图像进行打散、规定图片大小。
④将数值归一化。

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
train_data_gen = train_generator.flow_from_directory(batch_size=batch_size,directory=train_dir,shuffle=True,target_size=(height,width),class_mode="binary")
test_data_gen = train_generator.flow_from_directory(batch_size=batch_size,directory=test_dir,shuffle=True,target_size=(height,width),class_mode="binary")

构建网络
网络模型为:3层卷积池化层+Dropout+Flatten+两层全连接层

model = tf.keras.Sequential([tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)),tf.keras.layers.MaxPool2D(),tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"),tf.keras.layers.MaxPool2D(),tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"),tf.keras.layers.MaxPool2D(),tf.keras.layers.Dropout(0.5),tf.keras.layers.Flatten(),tf.keras.layers.Dense(512,activation="relu"),tf.keras.layers.Dense(1)
])
model.compile(optimizer="adam",loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),metrics=["acc"])

训练模型

history = model.fit_generator(train_data_gen,steps_per_epoch=train_all//batch_size,epochs=epochs,validation_data=test_data_gen,validation_steps=test_all//batch_size)

训练结果可视化

#训练结果可视化
accuracy = history.history["acc"]
test_accuracy = history.history["val_acc"]
loss = history.history["loss"]
test_loss = history.history["val_loss"]
epochs_range = range(epochs)
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs_range,accuracy,label = "Training Acc")
plt.plot(epochs_range,test_accuracy,label = "Test Acc")
plt.legend()
plt.title("Training and Test Acc")plt.subplot(1,2,2)
plt.plot(epochs_range,loss,label = "Training loss")
plt.plot(epochs_range,test_loss,label = "Test loss")
plt.legend()
plt.title("Training and Test loss")
plt.show()

自己搭建的网络对猫狗大战的数据进行训练,经过50次的epoch,最终的训练结果是70%左右的正确率,不是很高。
训练结果
其中训练集的模型准确率接近100%,但是测试集的正确率比较低。

Epoch 50/501/15 [=>............................] - ETA: 5s - loss: 0.0089 - acc: 1.00002/15 [===>..........................] - ETA: 3s - loss: 0.0071 - acc: 1.00003/15 [=====>........................] - ETA: 4s - loss: 0.0086 - acc: 1.00004/15 [=======>......................] - ETA: 3s - loss: 0.0113 - acc: 0.99785/15 [=========>....................] - ETA: 3s - loss: 0.0163 - acc: 0.99666/15 [===========>..................] - ETA: 3s - loss: 0.0163 - acc: 0.99587/15 [=============>................] - ETA: 2s - loss: 0.0149 - acc: 0.99658/15 [===============>..............] - ETA: 2s - loss: 0.0135 - acc: 0.99699/15 [=================>............] - ETA: 2s - loss: 0.0139 - acc: 0.9964
10/15 [===================>..........] - ETA: 1s - loss: 0.0145 - acc: 0.9959
11/15 [=====================>........] - ETA: 1s - loss: 0.0139 - acc: 0.9963
12/15 [=======================>......] - ETA: 1s - loss: 0.0155 - acc: 0.9953
13/15 [=========================>....] - ETA: 0s - loss: 0.0155 - acc: 0.9944
14/15 [===========================>..] - ETA: 0s - loss: 0.0170 - acc: 0.9943
15/15 [==============================] - 9s 595ms/step - loss: 0.0174 - acc: 0.9941 - val_loss: 1.2763 - val_acc: 0.7522


经过分析可得,出现了Overfitting的情况,我们对数据集做些调整。
对原来的数据集,做随机翻转,水平翻转,随机放大操作

train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255,rotation_range=45,#随机翻转width_shift_range=.15,height_shift_range=.15,horizontal_flip=True,#水平翻转zoom_range=0.5#放大操作)

最终的训练结果为:

Epoch 50/501/15 [=>............................] - ETA: 5s - loss: 0.5440 - acc: 0.70312/15 [===>..........................] - ETA: 9s - loss: 0.5151 - acc: 0.74223/15 [=====>........................] - ETA: 10s - loss: 0.5249 - acc: 0.70834/15 [=======>......................] - ETA: 10s - loss: 0.5082 - acc: 0.72275/15 [=========>....................] - ETA: 8s - loss: 0.4787 - acc: 0.7382 6/15 [===========>..................] - ETA: 8s - loss: 0.4764 - acc: 0.74177/15 [=============>................] - ETA: 7s - loss: 0.4801 - acc: 0.74068/15 [===============>..............] - ETA: 6s - loss: 0.4759 - acc: 0.75319/15 [=================>............] - ETA: 5s - loss: 0.4805 - acc: 0.7554
10/15 [===================>..........] - ETA: 4s - loss: 0.4877 - acc: 0.7459
11/15 [=====================>........] - ETA: 3s - loss: 0.4944 - acc: 0.7390
12/15 [=======================>......] - ETA: 2s - loss: 0.4969 - acc: 0.7325
13/15 [=========================>....] - ETA: 1s - loss: 0.4939 - acc: 0.7345
14/15 [===========================>..] - ETA: 0s - loss: 0.4933 - acc: 0.7368
15/15 [==============================] - 23s 2s/step - loss: 0.4957 - acc: 0.7377 - val_loss: 0.5394 - val_acc: 0.7277


Overfitting的情况得到了改善,但是准确率没有得到相应的提高。这是经过50个epoch之后的结果。

迁移学习

所谓的迁移学习就是通过别人已经训练好的网络直接对自己的数据进行处理。所采用的网络是别人已经训练好的VGG16网络。

模型加载

#引用VGG16模型
conv_base = tf.keras.applications.VGG16(weights='imagenet',include_top=False)
#设置为不可训练
conv_base.trainable =False
#模型搭建
model = tf.keras.Sequential()
model.add(conv_base)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(1,activation='sigmoid'))

模型训练

model.compile(optimizer='Adam',loss='binary_crossentropy',metrics=['acc'])
history = model.fit(train_data_gen,epochs=epochs,steps_per_epoch=train_all//batch_size,validation_data=test_data_gen,validation_steps=test_all//batch_size)

训练结果

Epoch 10/101/15 [=>............................] - ETA: 4:36 - loss: 0.3333 - acc: 0.84382/15 [===>..........................] - ETA: 2:14 - loss: 0.3917 - acc: 0.81253/15 [=====>........................] - ETA: 1:25 - loss: 0.3870 - acc: 0.81554/15 [=======>......................] - ETA: 1:01 - loss: 0.3967 - acc: 0.80825/15 [=========>....................] - ETA: 46s - loss: 0.3994 - acc: 0.8125 6/15 [===========>..................] - ETA: 36s - loss: 0.3935 - acc: 0.81397/15 [=============>................] - ETA: 28s - loss: 0.3976 - acc: 0.81608/15 [===============>..............] - ETA: 22s - loss: 0.3948 - acc: 0.81669/15 [=================>............] - ETA: 17s - loss: 0.3958 - acc: 0.8207
10/15 [===================>..........] - ETA: 13s - loss: 0.3970 - acc: 0.8198
11/15 [=====================>........] - ETA: 10s - loss: 0.3934 - acc: 0.8243
12/15 [=======================>......] - ETA: 7s - loss: 0.3946 - acc: 0.8246
13/15 [=========================>....] - ETA: 4s - loss: 0.3880 - acc: 0.8274
14/15 [===========================>..] - ETA: 2s - loss: 0.3900 - acc: 0.8251
15/15 [==============================] - 48s 3s/step - loss: 0.3910 - acc: 0.8237 - val_loss: 0.4328 - val_acc: 0.8025

因为VGG16模型训练起来比较耗时,所以我只设置了10个epoch,但是最终的结果已经比自己搭建的网络好很多了。

没有出现过拟合的情况(数据集经过预处理了),而且在epoch只有10的情况下,正确率已经达到了80%。

总结

通过对比我们可以发现, 自己搭建的网络在模型准确率上面,不如迁移学习所使用的网络模型,有可能是我的网络泛化能力比较差的问题。路过的大佬如果有更好的网络模型,可以讨论一下。

参考博客1

参考博客2

基于tensorflow2.0实现猫狗大战(搭建网络迁移学习)相关推荐

  1. 深度学习之基于Tensorflow2.0实现AlexNet网络

    在之前的实验中,一直是自己搭建或者是迁移学习进行物体识别,但是没有对某一个网络进行详细的研究,正好人工智能课需要按组上去展示成果,借此机会实现一下比较经典的网络,为以后的研究学习打下基础.本次基于Te ...

  2. 深度学习之基于Tensorflow2.0实现ResNet50网络

    理论上讲,当网络层数加深时,网络的性能会变强,而实际上,在不断的加深网络层数后,分类性能不会提高,而是会导致网络收敛更缓慢,准确率也随着降低.利用数据增强等方法抑制过拟合后,准确率也不会得到提高,出现 ...

  3. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(二)

    前情提要 基于上文所说 基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一) 我用tf2.0和Python3.7复现了一个基于CNN做音乐分类器.用余弦相似度评估距离的一个音乐推荐模型. ...

  4. 基于tensorflow2.0+CNN实现手势识别(全)

    基于tensorflow2.0+CNN实现手势识别 环境:windows10.pycharm2017.python3.64.tensorflow2.0.opencv3 我在github上分享了代码以及 ...

  5. 基于TensorFlow2.0的摄像头数字识别

    import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...

  6. 基于tensorflow2.0+opencv的花卉识别系统源码(含数据集)

    花卉识别-基于tensorflow2.3实现 完整代码下载地址:基于tensorflow2.0+opencv的花卉识别系统源码( 文件目录 # 数据下载地址 https://storage.googl ...

  7. 用 X 光检测新冠肺炎?也许孪生网络+迁移学习是更好的选择!

    始于2019年的新冠肺炎仍然肆虐全球,快速低成本检测该疾病成为了医学技术领域最热门的话题,早已有专家发现,核酸+胸部医学影像检测相结合是更可信的检测手段. 胸部X光影像是低成本的检测技术,但深度学习往 ...

  8. 基于Mindspore2.0的GPT2预训练模型迁移教程

    摘要: 这篇文章主要目的是为了让大家能够清楚如何用MindSpore2.0来进行模型的迁移. 本文分享自华为云社区<MindNLP-基于Mindspore2.0的GPT2预训练模型迁移教程> ...

  9. ResNet网络结构详解,网络搭建,迁移学习

    前言: 参考内容来自up:6.1 ResNet网络结构,BN以及迁移学习详解_哔哩哔哩_bilibili up的代码和ppt:https://github.com/WZMIAOMIAO/deep-le ...

最新文章

  1. 机器学习算法推导的较好例子
  2. 使用移动设备 连接到Exchange Server 2007
  3. 《剑指offer》c++版本 9. 用两个栈来实现一个队列
  4. 【学习笔记】JS进阶语法一window对象
  5. FIR IIR滤波器的设计
  6. mysql的tcp链接过程_tcp建立连接和断开连接过程
  7. 1305. GT考试
  8. Spring Security 认证执行流程
  9. Android学习之碎片的生命周期
  10. C语言编译php环境,vscode中C语言编译环境的配置方法(分享)
  11. RxJava flatMap,switchMap,concatMap
  12. 笔记3:数字和数学计算
  13. scp传输文件的命令
  14. 软件生命周期、面向对象基本概要
  15. 新手引导 自定义遮罩 点击穿透
  16. SAP ABAP 从入门至精通书籍推荐
  17. CICD概念 k8s DevOps
  18. ADDA数模转换(PCF8591)
  19. 这5个是不是元宇宙游戏遗珠?
  20. html背景图片自适应窗口大小

热门文章

  1. 浅谈C/C++中的指针和数组(一)
  2. html博客页面实验报告,×××实验报告
  3. 报头中的偏移量作用_C语言中函数的实现
  4. python sendto函数返回值_有返回值的函数amp;闭包(python)
  5. 关于Git下载上传项目的操作指令
  6. c语言程序输入n个数字排序,输入n个数字然后进行排序,用C语言编写。注意是n个数啊,不是确定的个数。...
  7. Qt的QStyle类的标准图标汇总
  8. java小票_Java编程打印购物小票实现代码
  9. JVM知识点总览:高级Java工程师面试必备
  10. MySql详解(六)