部分内容来自 博主史丹利复合田的Keras 入门课6 – 使用Inception V3模型进行迁移学习
地址:https://blog.csdn.net/tsyccnh/article/details/78889838

迁移学习主要分为两种

  • 第一种即所谓的transfer learning,迁移训练时,移掉最顶层,比如ImageNet训练任务的顶层就是一个1000输出的全连接层,换上新的顶层,比如输出为10的全连接层,然后训练的时候,只训练最后两层,即原网络的倒数第二层和新换的全连接输出层。可以说transfer learning将底层的网络当做了一个特征提取器来使用。
  • 第二种叫做fine tune,和transfer learning一样,换一个新的顶层,但是这一次在训练的过程中,所有的(或大部分)其它层都会经过训练。也就是底层的权重也会随着训练进行调整。

下载Inception V3相关数据

import osfrom tensorflow.keras import layers
from tensorflow.keras import Model
!wget --no-check-certificate \https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 \-O /tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5from tensorflow.keras.applications.inception_v3 import InceptionV3local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

设置pre_model架构(也叫base_model)

InceptionV3模型,其两个参数比较重要,一个是weights,如果是’imagenet’,Keras就会自动下载已经在ImageNet上训练好的参数,如果是None,系统会通过随机的方式初始化参数,目前该参数只有这两个选择。另一个参数是include_top,如果是True会保留全连接层。如果是False,会去掉顶层的全连接网络
这里的input_shape = (150, 150, 3)是我们输入网络的猫狗图片结构

pre_trained_model = InceptionV3(input_shape = (150, 150, 3), include_top = False, weights = None)pre_trained_model.load_weights(local_weights_file)

冻结pre_trained_model所有层,让骨架模型不再被训练

for layer in pre_trained_model.layers:layer.trainable = False

由于网络结构太长这里只展示部分

pre_trained_model.summary()

这里我们不采用InceptionV3的全部层次,让网络取到mixed7层作为输出连接到我们新添加的网络

last_layer = pre_trained_model.get_layer('mixed7')
print('last layer output shape: ', last_layer.output_shape)
last_output = last_layer.output

last layer output shape: (None, 7, 7, 768)

给网络添加我们自己的几个层次,采用dropout减少过拟合

from tensorflow.keras.optimizers import RMSprop# Flatten the output layer to 1 dimension
x = layers.Flatten()(last_output)
# Add a fully connected layer with 1,024 hidden units and ReLU activation
x = layers.Dense(1024, activation='relu')(x)
# Add a dropout rate of 0.2
x = layers.Dropout(0.2)(x)
# Add a final sigmoid layer for classification
x = layers.Dense  (1, activation='sigmoid')(x)           model = Model( inputs=pre_trained_model.input,  outputs=x) model.compile(optimizer = RMSprop(lr=0.0001), loss = 'binary_crossentropy', metrics = ['acc'])

定义目录,采用数据增强

base_dir = '/tmp/cats_and_dogs_filtered'train_dir = os.path.join( base_dir, 'train')
validation_dir = os.path.join( base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats') # Directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs') # Directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats') # Directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')# Directory with our validation dog picturestrain_cat_fnames = os.listdir(train_cats_dir)
train_dog_fnames = os.listdir(train_dogs_dir)# Add our data-augmentation parameters to ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator( rescale = 1.0/255. )# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (150, 150))     # Flow validation images in batches of 20 using test_datagen generator
validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (150, 150))

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

训练网络

history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 20,validation_steps = 50,verbose = 2)

Epoch 1/20
100/100 - 29s - loss: 0.3360 - acc: 0.8655 - val_loss: 0.1211 - val_acc: 0.9470
Epoch 2/20
100/100 - 23s - loss: 0.2193 - acc: 0.9145 - val_loss: 0.1096 - val_acc: 0.9640
Epoch 3/20
100/100 - 23s - loss: 0.2038 - acc: 0.9290 - val_loss: 0.0888 - val_acc: 0.9660
Epoch 4/20
100/100 - 22s - loss: 0.1879 - acc: 0.9315 - val_loss: 0.1198 - val_acc: 0.9590
Epoch 5/20
100/100 - 23s - loss: 0.1760 - acc: 0.9415 - val_loss: 0.1155 - val_acc: 0.9660
Epoch 6/20
100/100 - 22s - loss: 0.1771 - acc: 0.9375 - val_loss: 0.1540 - val_acc: 0.9450
Epoch 7/20
100/100 - 23s - loss: 0.1916 - acc: 0.9370 - val_loss: 0.1616 - val_acc: 0.9550
Epoch 8/20
100/100 - 22s - loss: 0.1594 - acc: 0.9440 - val_loss: 0.1422 - val_acc: 0.9630
Epoch 9/20
100/100 - 23s - loss: 0.1669 - acc: 0.9465 - val_loss: 0.1099 - val_acc: 0.9650
Epoch 10/20
100/100 - 23s - loss: 0.1677 - acc: 0.9445 - val_loss: 0.1245 - val_acc: 0.9600
Epoch 11/20
100/100 - 22s - loss: 0.1653 - acc: 0.9470 - val_loss: 0.0918 - val_acc: 0.9730
Epoch 12/20
100/100 - 22s - loss: 0.1542 - acc: 0.9455 - val_loss: 0.1623 - val_acc: 0.9570
Epoch 13/20
100/100 - 22s - loss: 0.1525 - acc: 0.9520 - val_loss: 0.1087 - val_acc: 0.9670
Epoch 14/20
100/100 - 23s - loss: 0.1454 - acc: 0.9565 - val_loss: 0.1314 - val_acc: 0.9640
Epoch 15/20
100/100 - 22s - loss: 0.1279 - acc: 0.9525 - val_loss: 0.1515 - val_acc: 0.9630
Epoch 16/20
100/100 - 23s - loss: 0.1255 - acc: 0.9530 - val_loss: 0.1306 - val_acc: 0.9650
Epoch 17/20
100/100 - 22s - loss: 0.1430 - acc: 0.9575 - val_loss: 0.1226 - val_acc: 0.9660
Epoch 18/20
100/100 - 23s - loss: 0.1350 - acc: 0.9510 - val_loss: 0.1583 - val_acc: 0.9520
Epoch 19/20
100/100 - 22s - loss: 0.1288 - acc: 0.9580 - val_loss: 0.1170 - val_acc: 0.9710
Epoch 20/20
100/100 - 22s - loss: 0.1363 - acc: 0.9550 - val_loss: 0.1260 - val_acc: 0.9660

绘制损失和准确率图

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend(loc=0)
plt.figure()plt.show()

tensorflow实现猫狗分类器(三)Inception V3迁移学习相关推荐

  1. 基于深度学习的图像分类:使用卷积神经网络实现猫狗分类器

    摘要: 深度学习在计算机视觉领域中具有广泛的应用.本文将介绍如何使用卷积神经网络(CNN)实现一个猫狗分类器.我们将使用Python和TensorFlow框架搭建一个简单的卷积神经网络模型,并利用猫狗 ...

  2. 猫狗大战——基于TensorFlow的猫狗识别(2)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...

  3. 猫狗大战——基于TensorFlow的猫狗识别(1)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 简介: 关于猫狗识别是机器学习和深度学习的一个经典实例,下来小玉把自己做的基于CNN卷积神经网络利用Tensorflow框架进行猫狗的识别的程序 ...

  4. 基于tensorflow的猫狗分类

    基于tensorflow的猫狗分类 数据的准备 引入库 数据集来源 准备数据 显示一张图片的内容 搭建网络模型 构建网络 模型的编译 数据预处理 模型的拟合与评估 模型的拟合 预测一张图片 损失和精度 ...

  5. 体验AI乐趣:基于AI Gallery的二分类猫狗图片分类小数据集自动学习

    摘要:直接使用AI Gallery里面现有的数据集进行自动学习训练,很简单和方便,节约时间,不用自己去训练了,AI Gallery 里面有很多类似的有趣数据集,也非常好玩,大家一起试试吧. 本文分享自 ...

  6. Tensorflow(七)Retrain Google Inception V3

    1.下载Inception V3模型 Download-Link 在tensorflow官网中可以直接下载,下载完压缩包以后解压,注意不要删除这个压缩包,后面可能会用到,然后在同目录下创建一个log文 ...

  7. 深度学习-使用tensorflow实现猫狗识别

    最近一直在撸猫,为了猫主子的事情忧三愁四,皱纹多了不少,头发也掉了好几根,神态也多了几分忧郁,唯一不变的还是那份闲鱼的懒散和浪荡的心. 要说到深度学习图像分类的经典案例之一,那就是猫狗大战了.猫和狗在 ...

  8. tensorflow实现猫狗分类项目

    最近暑假有时间,因此想学一点东西,然后又因为限于自己电脑的显卡是A卡,不能GPU加速,也用不了pytorch框架,所以就选择tensorflow. 现在也在刚刚入坑tensorflow因此做的项目比较 ...

  9. 猫狗图像数据集上的深度学习模型性能对比

    LeNet模型简介 1. LeNet LeNet-5由七层组成(不包括输入层),每一层都包含可训练权重.通过卷积.池化等操作进行特征提取,最后利用全连接实现分类识别,下面是他的网络结构示意图: C:卷 ...

最新文章

  1. gbdt xgboost 贼难理解!
  2. python字符串能减吗_在python中减去两个字符串(Subtract two strings in python)
  3. python 皮尔森相关系数
  4. python中文分词jieba总结
  5. 【转】三五个人十来条枪 如何走出软件作坊成为开发正规军
  6. linux系统管理与服务器配置高志君_如何在 Linux 上安装、配置 NTP 服务器和客户端?...
  7. python常见安装
  8. unity_AR(一) 安卓手机无法显示模型和无法播放动画问题
  9. 虚拟化技术发展编年史
  10. 数据结构课设之航空订票系统(Java)下载链接在文末
  11. moodle php代码解读_Moodle学习笔记
  12. Python实现日程表
  13. android 签到自定义,Android日历签到,超级简单的实现方式
  14. 阿里云主机修改操作系统 详细步骤
  15. java计算长方体面积和周长
  16. java用户登录注册
  17. 怎么把html转成mp4,怎么把其他视频格式转成常用的mp4形式?
  18. 常规通知(Notification)模板
  19. 公民实验室:史上危险的手机间谍软件已感染45个国家/地区
  20. bigdata101:Permission denied (publickey .....) 这次的不一样

热门文章

  1. 前端笔记:边框和阴影
  2. 如果被执行人对法院判决的强制执行拖延时间拒不执行,怎么办?
  3. java mx150显卡够了吗_MX150显卡怎么样 MX150相当于什么显卡
  4. linux常用rootkit技术,unix下的 rootkit
  5. IOS应用开发之自动旋转与调整大小
  6. Python进度条实现
  7. Gorm 相关问题记录
  8. Android Cocos2dx引擎 prv.ccz/plist/so等优化缓存文件,手把手ida教你逆向project反编译apk库等文件...
  9. 干扰网络信号的app_解决Wi-Fi无线信号干扰的方法
  10. python数组的定义和基本使用