学习深度学习需要从简单模型入手,可以选择手写字识别或者猫vs狗数据入手。
  这篇文章从猫和狗的识别入手对深度学习有一个简单的认知, 最后可以输入自己的图片做测试。
  文章结构如下:

  1. 猫狗数据集
  2. 数据预处理
  3. 模型加载与训练
  4. 输入自己的数据做测试

深度学习框架

  文章用的是tensorflow2.0版本的深度学习框架。所以开始之前需要下载python,安装tensorflow2.0的库。

1.猫狗数据集

数据集导入

  tensorflow_datasets中有许多数据集,我们训练用的猫狗数据集就从tensorflow_datasets中引入即可。

import tensorflow_datasets as tfds
tfds.disable_progress_bar()SPLIT_WEIGHTS = (8, 1, 1)  # 将数据集按8:1:1分为训练集,验证集,测试集。
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)# 加载数据集
(raw_train, raw_validation, raw_test), metadata = tfds.load('cats_vs_dogs', split=list(splits),with_info=True, as_supervised=True)print(raw_train)
print(raw_validation)
print(raw_test)

  打印的结果是训练集,验证集,测试集的shape,它们都是三通道的数组与标签组成,形如((None, None, 3), (1)),如下所示:

<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>

数据查看

  先看看数据长啥样吧。
  代码中metadata是数据自带的元数据,这里是cat或者dog。可以不用关心这个数据。
  raw_train.take(2)表示取两个数据。

get_label_name = metadata.features['label'].int2strfor image, label in raw_train.take(2):print(“label = ”,label)plt.figure()plt.imshow(image)plt.title(get_label_name(label))

  输出的tf.Tensor(1, shape=(), dtype=int64)表示这是一个tensor变量(不清楚可以百度),并且值为1,这里1是狗,0是猫。
  如下所示:

至此,数据集的载入完成,接下来是对数据做一些预处理,做归一化与resize等。

2. 数据预处理

  由于图片大小不一样,所以我们在训练时需要将其resize到一个大小,以便使用批处理,至于为什么用批处理,可以看数据预处理部分。
  注意这里对数据做的几个处理,一是将数据类型变为float32(tensorflow模型输入的统一数据类型), 二是做归一化,将输入图片的像素值变为(0,1), 最后是做resize。
  tensorflow2.0对数据的处理方面也很人性化,很容易理解,如下的raw_train为一个dataset类型,可以调用.map对它做映射,.shuffle打乱顺序,.batch设置批量。对自己的数据集做预处理载入可以看说明文档。

# 图片预处理IMG_SIZE = 160 # All images will be resized to 160x160def format_example(image, label):image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return image, labeltrain = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

3. 模型加载与训练

上面做好了模型的输入后,接下来就要进行模型的加载与训练了。模型加载完了,激动人心的测试阶段还会远吗。

加载内置模型作为基础模型

先加载MobileNetV2模型作为卷积层。
加载tensorflow内置的模型方法,以及参数可以参看模型加载部分。

# 加载模型
import tensorflow as tf
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)# 以 MobileNet V2为基础模型。
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
feature_batch = base_model(image_batch)

添加全连接层

在上述模型的基础上添加全连接层。
在制作模型时涉及的函数参数可以参看tensorflow2.0 api。其中Sequential是一个序列,模型按照序列中的顺序执行。

# 修改模型global_average_layer = tf.keras.layers.GlobalAveragePooling2D()# 分类层, Dense中的参数为输出的类别数量,这里分1类,即只识别狗。
prediction_layer = keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer
])

模型训练

tensorflow2.0的模型训练过程就特别人性化,特别简单了,很多参数已经默认,一般需要注意输入与输出即可,这里输入用dataset,输出为预测值。后面再细看输出是什么。

base_learning_rate = 0.0001  # 学习率,代表每次优化的大小,一般1e-3与1e-4比较合适initial_epochs = 5  # 训练的轮数model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),loss='binary_crossentropy',metrics=['accuracy'])model.fit(train_batches,epochs=initial_epochs,validation_data=validation_batches)weight_path = os.path.join('cat_vs_dogs')
model.save_weights(weight_path)  # 保存模型参数

4. 输入自己的数据做测试

激动人心的时刻终于来了,接下来用自己的照片验证模型识别效果,我百度了几张猫和狗的照片作为例子验证。

from PIL import Imageweight_path = os.path.join('cat_vs_dogs')
model.load_weights(weight_path)def format_test(image):#对输入做预处理image = np.array(image)image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return imageimage = Image.open("dog2.jpg")
plt.imshow(image)
test_image = format_test(image)
test_image = test_image[np.newaxis,:,:,:]  # 给输入增加一个维度变为[1, h, w, c], 这里batch为1pred = model.predict(test_image)
print(pred)# 结果大于0为狗,小于0为猫
pred[pred>0]=1
pred[pred<=0]=0
print(pred)

结果中的predict为[batch, 1]维数组,在文章的例子中,batch取1,所以一次训练一张图。同时,由于二分类可以只识别狗,概率大于0的是狗,小于0的是猫, 这是模型中的分类层决定的。
结果:


猫狗识别是比较简单的分类网络,输入的图片与label分别为n通道的图与数字, 输出的为各个类别的概率列表。

深度学习入门之猫vs狗(超简单)相关推荐

  1. 深度学习入门的建议_来自《简单粗暴Tensorflow2》

    参考资料与推荐阅读 一.如果你是一名在校大学生,具有较好的数学基础,可以从以下教材入手,作为学习机器学习的起点: 二.如果你希望更具实践性的内容,推荐以下书籍: 三.如果你对大学的知识已经生疏,或者还 ...

  2. 深度学习入门-误差反向传播法(人工神经网络实现mnist数据集识别)

    文章目录 误差反向传播法 5.1 链式法则与计算图 5.2 计算图代码实践 5.3激活函数层的实现 5.4 简单矩阵求导 5.5 Affine 层的实现 5.6 softmax-with-loss层计 ...

  3. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  4. 模块一:深度学习入门算法

    模块一:深度学习入门算法 1.深度学习必备知识 1.1深度学习要解决的问题 机器学习流程: 数据获取 -----> 特征工程 -----> 建立模型 ------> 评估与应用 特征 ...

  5. 深度学习入门笔记(二十):经典神经网络(LeNet-5、AlexNet和VGGNet)

    欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...

  6. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  7. 深度学习入门笔记(五):神经网络的学习

    专栏--深度学习入门笔记 推荐文章 深度学习入门笔记(一):机器学习基础 深度学习入门笔记(二):神经网络基础 深度学习入门笔记(三):感知机 深度学习入门笔记(四):神经网络 深度学习入门笔记(五) ...

  8. 给深度学习入门者的Python快速教程 - numpy和Matplotlib篇

    转载自:https://zhuanlan.zhihu.com/p/24309547 本篇部分代码的下载地址: https://github.com/frombeijingwithlove/dlcv_f ...

  9. 深度学习入门(一)——深度学习是什么?

    深度学习入门(一)--深度学习是什么? 看了标题,你心中或许已经有了疑惑.什么是深度学习?这和人工智能有什么关系吗?神经网络不是生物学知识吗?什么是全连接神经网络?如果你对本次技术分享内容足够感兴趣且 ...

  10. pytorch深度学习入门笔记

    Pytorch 深度学习入门笔记 作者:梅如你 学习来源: 公众号: 阿力阿哩哩.土堆碎念 B站视频:https://www.bilibili.com/video/BV1hE411t7RN? 中国大学 ...

最新文章

  1. 表格检测开源网络推荐
  2. Spring Cloud alibaba版本对应
  3. Java 8中的策略模式
  4. jQuery常用的层次选择器
  5. Docker容器中的Linux机器快速设置国内源
  6. JS常用的设计模式(2)——简单工厂模式
  7. 你们期待的小屏旗舰来了: 骁龙855 没有刘海!
  8. vue-cli入门(四)——vue-resource登录注册实例
  9. webpack2.7.0配置不同的打包环境
  10. Fortran 注释符号
  11. linux两台服务器文件实时同步
  12. minecraft java文件夹_Minecraft游戏下载 文件结构说明
  13. java math 三角函数_Java中的三角函数
  14. Racket编程指南——1 欢迎来到Racket!
  15. 统计学第一类错误和第二类错误
  16. RHCE linux学习第一天
  17. 数智化转型中的零售餐饮行业
  18. 我们公司使用了 5 年的系统限流方案 ,从实现到部署实战详解,稳的一B
  19. 《英语口语900句 (624页+360分钟录音)》(Oral English 900 Expressions)
  20. 【其他】逻辑、逻辑推理概念

热门文章

  1. Java词法分析器的设计与实现
  2. 游戏直播用哪个录屏软件好?
  3. 一笔画: 表现绘画过程的美
  4. 通达信板块监控指标_板块监控及使用方法指标详解 通达信板块监控
  5. 用 VC2012 产生脱离VC运行库的 C/C++ 程序
  6. 银行对公账户编码规则
  7. arduino atmega328P MCP4725 proteus 仿真 程序
  8. 用puttygen工具把私钥id_rsa转换成公钥id_rsa.ppk
  9. 基于SSM的宠物领养系统(附源码)
  10. [nssl 1322][jzoj cz 2109] 清兵线 {dp}