导入工具包

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

首先得准备好两类分类的图片,并且准备好分开两个文件夹

这里是有2000张训练集 和 1000张验证集
然后指定好数据路径(训练集和验证集)

# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

构建卷积神经网络模型

几层都可以,大家可以随意玩
如果用CPU训练,可以把输入设置的更小一些,一般输入大小更主要的决定了训练速度

#构造卷积神经网络模型
#对于GPU可以输入224*224
#对于CPU输入64*64,速度可以快10倍#以下第一种方法科学 直观展示summary
model = tf.keras.models.Sequential([# 如果训练慢,可以把数据设置的更小一些tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),# 为全连接层准备tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),# 二分类sigmoid就够了tf.keras.layers.Dense(1, activation='sigmoid')
])
"""第二种构建方法
model = tf.keras.models.Sequential([# 如果训练慢,可以把数据设置的更小一些tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),# 为全连接层准备tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),# 二分类sigmoid就够了tf.keras.layers.Dense(1, activation='sigmoid')
])
model.build(input_shape=(None,64, 64, 3))
"""
model.summary()

配置训练器

model.compile(loss='binary_crossentropy',optimizer=Adam(lr=1e-4),metrics=['acc'])

数据归一化预处理

#生成批次!! 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

训练网络模型

直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step

train_generator = train_datagen.flow_from_directory(train_dir,  # 文件夹路径target_size=(64, 64),  # 指定resize成的大小batch_size=20,# 如果one-hot就是categorical,二分类用binary就可以class_mode='binary')validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(64, 64),batch_size=20,class_mode='binary')history = model.fit_generator(train_generator,steps_per_epoch=100,  # 2000 images = batch_size * steps_per_epochepochs=20,validation_data=validation_generator,validation_steps=50,  # 1000 images = batch_size * validation_stepsverbose=2)

训练输出参数

CPU上训练会稍慢~~

Epoch 1/20
100/100 - 16s - loss: 0.6913 - acc: 0.5235 - val_loss: 0.6750 - val_acc: 0.5390
Epoch 2/20
100/100 - 15s - loss: 0.6707 - acc: 0.5810 - val_loss: 0.6553 - val_acc: 0.6310
Epoch 3/20
100/100 - 15s - loss: 0.6340 - acc: 0.6365 - val_loss: 0.6229 - val_acc: 0.6530
Epoch 4/20
100/100 - 16s - loss: 0.6038 - acc: 0.6780 - val_loss: 0.5941 - val_acc: 0.6860
Epoch 5/20
100/100 - 15s - loss: 0.5789 - acc: 0.6900 - val_loss: 0.5841 - val_acc: 0.6930
Epoch 6/20
100/100 - 14s - loss: 0.5368 - acc: 0.7335 - val_loss: 0.5678 - val_acc: 0.6850
Epoch 7/20
100/100 - 13s - loss: 0.5087 - acc: 0.7530 - val_loss: 0.5594 - val_acc: 0.7110
Epoch 8/20
100/100 - 13s - loss: 0.4876 - acc: 0.7620 - val_loss: 0.5643 - val_acc: 0.7150
Epoch 9/20
100/100 - 13s - loss: 0.4568 - acc: 0.7910 - val_loss: 0.5871 - val_acc: 0.7050
Epoch 10/20
100/100 - 13s - loss: 0.4261 - acc: 0.8155 - val_loss: 0.5421 - val_acc: 0.7230
Epoch 11/20
100/100 - 15s - loss: 0.3908 - acc: 0.8315 - val_loss: 0.5770 - val_acc: 0.7010
Epoch 12/20
100/100 - 17s - loss: 0.3952 - acc: 0.8295 - val_loss: 0.5398 - val_acc: 0.7340
Epoch 13/20
100/100 - 17s - loss: 0.3583 - acc: 0.8435 - val_loss: 0.5411 - val_acc: 0.7340
Epoch 14/20
100/100 - 16s - loss: 0.3320 - acc: 0.8670 - val_loss: 0.5619 - val_acc: 0.7190
Epoch 15/20
100/100 - 15s - loss: 0.3046 - acc: 0.8785 - val_loss: 0.5582 - val_acc: 0.7410
Epoch 16/20
100/100 - 15s - loss: 0.2794 - acc: 0.8990 - val_loss: 0.5551 - val_acc: 0.7410
Epoch 17/20
100/100 - 15s - loss: 0.2554 - acc: 0.9080 - val_loss: 0.5435 - val_acc: 0.7530
Epoch 18/20
100/100 - 15s - loss: 0.2445 - acc: 0.9055 - val_loss: 0.5596 - val_acc: 0.7440
Epoch 19/20
100/100 - 14s - loss: 0.2198 - acc: 0.9235 - val_loss: 0.5629 - val_acc: 0.7510
Epoch 20/20
100/100 - 14s - loss: 0.1923 - acc: 0.9385 - val_loss: 0.5818 - val_acc: 0.7430

效果展示

import matplotlib.pyplot as plt"""使用训练时的准确度变化进行画图"""
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()plt.show()


可以看到训练集的训练效果非常理想,可验证集却效果非常不好。
到底什么原因呢?
这就是神经网络的通病:

过拟合~

接下来我会在另一篇博客介绍如何解决过拟合这个毛病的方法。Click~
首先先使用最直观的数据增强。
Click~

对于以上代码的训练集进行一个数据增强

#train_datagen = ImageDataGenerator(rescale=1./255)
train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')

效果展示:
可以看到效果好了很多~~~

然后再加额外的dropout层看看效果~

model = tf.keras.models.Sequential([# 如果训练慢,可以把数据设置的更小一些tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),# 为全连接层准备tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dropout(0.5), #杀死率为0.5 dropout一般加在全连接层后面# 二分类sigmoid就够了tf.keras.layers.Dense(1, activation='sigmoid')
])


深度学习框架tensorflow二实战(训练一个简单二分类模型)相关推荐

  1. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  2. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  3. DL框架之TensorFlow:深度学习框架TensorFlow Core(低级别TensorFlow API)的简介、安装、使用方法之详细攻略

    DL框架之TensorFlow:TensorFlow Core(低级别TensorFlow API)的简介.安装.使用方法之详细DL框架之TensorFlow:深度学习框架TensorFlow Cor ...

  4. DL框架之Tensorflow:深度学习框架Tensorflow的简介、安装、使用方法之详细攻略

    DL框架之Tensorflow:深度学习框架Tensorflow的简介.安装.使用方法之详细攻略 目录 Tensorflow的简介 1.描述 2.TensorFlow的六大特征 3.了解Tensorf ...

  5. 深度学习框架tensorflow学习与应用——代码笔记11(未完成)

    11-1 第十周作业-验证码识别(未完成) #!/usr/bin/env python # coding: utf-8# In[1]:import os import tensorflow as tf ...

  6. 深度学习框架TensorFlow系列之(五)优化器1

    1 背景 梯度下降算法是目前最流行的优化算法之一,并且被用来优化神经网络的模型.业界知名的深度学习框架TensorFlow.Caffe等均包含了各种关于梯度下降优化器的实现.然而这些优化器经常被用作黑 ...

  7. 快速了解深度学习框架--tensorflow(更新中)

    深度学习框架(工具)简单来说即库,需要import,比如tensorflow,Caffe- 深度学习框架提供了一系列的深度学习的组件(对于通用的算法,里面会有实现),当需要使用新的算法的时候就需要用户 ...

  8. 4.1 深度学习框架-TensorFlow

    4.1 深度学习框架-TensorFlow 学习目标 目标 了解Tensorflow框架的组成.接口 了解TensorFlow框架的安装 知道tf.keras的特点和使用 应用 无 4.1.1 常见深 ...

  9. DL框架:主流深度学习框架(TensorFlow/Pytorch/Caffe/Keras/CNTK/MXNet/Theano/PaddlePaddle)简介、多个方向比较、案例应用之详细攻略

    DL框架:主流深度学习框架(TensorFlow/Pytorch/Caffe/Keras/CNTK/MXNet/Theano/PaddlePaddle)简介.多个方向比较.案例应用之详细攻略 目录 深 ...

  10. TensorFlow:深度学习框架TensorFlow TensorFlow_GPU的简介、安装、测试之详细攻略

    TensorFlow:深度学习框架TensorFlow & TensorFlow_GPU的简介.安装.测试之详细攻略 目录 TensorFlow的简介 TensorFlow的安装 1.tens ...

最新文章

  1. 自带数据线的迷你数显充电宝,旅途必备
  2. canvas rotate 累加旋转_【教研动态】音乐活动中,使用材料累加情境的适宜性
  3. 同一个事务里面对同一条数据做2次修改_要我说,多线程事务它必须就是个伪命题!
  4. 研究UEVENT相关东西,看到2篇优秀的博文,转载与此
  5. VS2015中的项目类图
  6. docker使用阿里云镜像仓库docker
  7. 手机应用开发者必看:移动开发者大势图
  8. 物理内存管理之zone详解
  9. 组建BXP无盘网络 (一)
  10. win10无法执行vbs脚本
  11. 微信小程序中实现人脸识别认证
  12. social network 学习心得
  13. PS学习之动态表情制作
  14. linux服务器默认多久断开ssh,SSH超时自动断开问题解决
  15. unity 3d开发的大型网络游戏
  16. 自媒体必备工具合集分享
  17. 蓝桥--不同非空子串
  18. Python每日一练-----整数转罗马数字
  19. 解决linux下.AppImage文件无法运行问题
  20. 使用WebRTC搭建前端视频聊天室——点对点通信篇

热门文章

  1. c语言统计链表值的总合,C语言链表综合操作
  2. 微软、阿里云们的下一个十年:深耕政企市场,打破现有格局
  3. 如何解决animate运行时提示,应该为在运行时可能编辑的任何文本嵌入字体,具有使用设备字体设置的文本除外。“
  4. ArcGIS教程:了解冲突解决和制图综合
  5. 【博客282】udp socket的recvfrom函数的一个易错问题
  6. 4G通讯NFC读卡器|读写器ACR123U-C8性能与应用攻略
  7. Python第四天作业
  8. 如何清除chrome浏览器缓存
  9. Google浏览器强制刷新、清楚缓存(其他浏览器应该也行)
  10. 出现开机慢、开机黑屏长时间的进(转至卡饭论坛,帖子最早出现在爱毒霸社区论坛)之二