深度学习入门之猫vs狗(超简单)
学习深度学习需要从简单模型入手,可以选择手写字识别或者猫vs狗数据入手。
这篇文章从猫和狗的识别入手对深度学习有一个简单的认知, 最后可以输入自己的图片做测试。
文章结构如下:
- 猫狗数据集
- 数据预处理
- 模型加载与训练
- 输入自己的数据做测试
深度学习框架
文章用的是tensorflow2.0版本的深度学习框架。所以开始之前需要下载python,安装tensorflow2.0的库。
1.猫狗数据集
数据集导入
tensorflow_datasets中有许多数据集,我们训练用的猫狗数据集就从tensorflow_datasets中引入即可。
import tensorflow_datasets as tfds
tfds.disable_progress_bar()SPLIT_WEIGHTS = (8, 1, 1) # 将数据集按8:1:1分为训练集,验证集,测试集。
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)# 加载数据集
(raw_train, raw_validation, raw_test), metadata = tfds.load('cats_vs_dogs', split=list(splits),with_info=True, as_supervised=True)print(raw_train)
print(raw_validation)
print(raw_test)
打印的结果是训练集,验证集,测试集的shape,它们都是三通道的数组与标签组成,形如((None, None, 3), (1)),如下所示:
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
数据查看
先看看数据长啥样吧。
代码中metadata是数据自带的元数据,这里是cat或者dog。可以不用关心这个数据。
raw_train.take(2)表示取两个数据。
get_label_name = metadata.features['label'].int2strfor image, label in raw_train.take(2):print(“label = ”,label)plt.figure()plt.imshow(image)plt.title(get_label_name(label))
输出的tf.Tensor(1, shape=(), dtype=int64)表示这是一个tensor变量(不清楚可以百度),并且值为1,这里1是狗,0是猫。
如下所示:
至此,数据集的载入完成,接下来是对数据做一些预处理,做归一化与resize等。
2. 数据预处理
由于图片大小不一样,所以我们在训练时需要将其resize到一个大小,以便使用批处理,至于为什么用批处理,可以看数据预处理部分。
注意这里对数据做的几个处理,一是将数据类型变为float32(tensorflow模型输入的统一数据类型), 二是做归一化,将输入图片的像素值变为(0,1), 最后是做resize。
tensorflow2.0对数据的处理方面也很人性化,很容易理解,如下的raw_train为一个dataset类型,可以调用.map对它做映射,.shuffle打乱顺序,.batch设置批量。对自己的数据集做预处理载入可以看说明文档。
# 图片预处理IMG_SIZE = 160 # All images will be resized to 160x160def format_example(image, label):image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return image, labeltrain = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
3. 模型加载与训练
上面做好了模型的输入后,接下来就要进行模型的加载与训练了。模型加载完了,激动人心的测试阶段还会远吗。
加载内置模型作为基础模型
先加载MobileNetV2模型作为卷积层。
加载tensorflow内置的模型方法,以及参数可以参看模型加载部分。
# 加载模型
import tensorflow as tf
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)# 以 MobileNet V2为基础模型。
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
feature_batch = base_model(image_batch)
添加全连接层
在上述模型的基础上添加全连接层。
在制作模型时涉及的函数参数可以参看tensorflow2.0 api。其中Sequential是一个序列,模型按照序列中的顺序执行。
# 修改模型global_average_layer = tf.keras.layers.GlobalAveragePooling2D()# 分类层, Dense中的参数为输出的类别数量,这里分1类,即只识别狗。
prediction_layer = keras.layers.Dense(1)model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer
])
模型训练
tensorflow2.0的模型训练过程就特别人性化,特别简单了,很多参数已经默认,一般需要注意输入与输出即可,这里输入用dataset,输出为预测值。后面再细看输出是什么。
base_learning_rate = 0.0001 # 学习率,代表每次优化的大小,一般1e-3与1e-4比较合适initial_epochs = 5 # 训练的轮数model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),loss='binary_crossentropy',metrics=['accuracy'])model.fit(train_batches,epochs=initial_epochs,validation_data=validation_batches)weight_path = os.path.join('cat_vs_dogs')
model.save_weights(weight_path) # 保存模型参数
4. 输入自己的数据做测试
激动人心的时刻终于来了,接下来用自己的照片验证模型识别效果,我百度了几张猫和狗的照片作为例子验证。
from PIL import Imageweight_path = os.path.join('cat_vs_dogs')
model.load_weights(weight_path)def format_test(image):#对输入做预处理image = np.array(image)image = tf.cast(image, tf.float32)image = (image/127.5) - 1image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))return imageimage = Image.open("dog2.jpg")
plt.imshow(image)
test_image = format_test(image)
test_image = test_image[np.newaxis,:,:,:] # 给输入增加一个维度变为[1, h, w, c], 这里batch为1pred = model.predict(test_image)
print(pred)# 结果大于0为狗,小于0为猫
pred[pred>0]=1
pred[pred<=0]=0
print(pred)
结果中的predict为[batch, 1]维数组,在文章的例子中,batch取1,所以一次训练一张图。同时,由于二分类可以只识别狗,概率大于0的是狗,小于0的是猫, 这是模型中的分类层决定的。
结果:
猫狗识别是比较简单的分类网络,输入的图片与label分别为n通道的图与数字, 输出的为各个类别的概率列表。
深度学习入门之猫vs狗(超简单)相关推荐
- 深度学习入门的建议_来自《简单粗暴Tensorflow2》
参考资料与推荐阅读 一.如果你是一名在校大学生,具有较好的数学基础,可以从以下教材入手,作为学习机器学习的起点: 二.如果你希望更具实践性的内容,推荐以下书籍: 三.如果你对大学的知识已经生疏,或者还 ...
- 深度学习入门-误差反向传播法(人工神经网络实现mnist数据集识别)
文章目录 误差反向传播法 5.1 链式法则与计算图 5.2 计算图代码实践 5.3激活函数层的实现 5.4 简单矩阵求导 5.5 Affine 层的实现 5.6 softmax-with-loss层计 ...
- 深度学习入门之PyTorch学习笔记:深度学习介绍
深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...
- 模块一:深度学习入门算法
模块一:深度学习入门算法 1.深度学习必备知识 1.1深度学习要解决的问题 机器学习流程: 数据获取 -----> 特征工程 -----> 建立模型 ------> 评估与应用 特征 ...
- 深度学习入门笔记(二十):经典神经网络(LeNet-5、AlexNet和VGGNet)
欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...
- 深度学习入门之PyTorch学习笔记:卷积神经网络
深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...
- 深度学习入门笔记(五):神经网络的学习
专栏--深度学习入门笔记 推荐文章 深度学习入门笔记(一):机器学习基础 深度学习入门笔记(二):神经网络基础 深度学习入门笔记(三):感知机 深度学习入门笔记(四):神经网络 深度学习入门笔记(五) ...
- 给深度学习入门者的Python快速教程 - numpy和Matplotlib篇
转载自:https://zhuanlan.zhihu.com/p/24309547 本篇部分代码的下载地址: https://github.com/frombeijingwithlove/dlcv_f ...
- 深度学习入门(一)——深度学习是什么?
深度学习入门(一)--深度学习是什么? 看了标题,你心中或许已经有了疑惑.什么是深度学习?这和人工智能有什么关系吗?神经网络不是生物学知识吗?什么是全连接神经网络?如果你对本次技术分享内容足够感兴趣且 ...
- pytorch深度学习入门笔记
Pytorch 深度学习入门笔记 作者:梅如你 学习来源: 公众号: 阿力阿哩哩.土堆碎念 B站视频:https://www.bilibili.com/video/BV1hE411t7RN? 中国大学 ...
最新文章
- 表格检测开源网络推荐
- Spring Cloud alibaba版本对应
- Java 8中的策略模式
- jQuery常用的层次选择器
- Docker容器中的Linux机器快速设置国内源
- JS常用的设计模式(2)——简单工厂模式
- 你们期待的小屏旗舰来了: 骁龙855 没有刘海!
- vue-cli入门(四)——vue-resource登录注册实例
- webpack2.7.0配置不同的打包环境
- Fortran 注释符号
- linux两台服务器文件实时同步
- minecraft java文件夹_Minecraft游戏下载 文件结构说明
- java math 三角函数_Java中的三角函数
- Racket编程指南——1 欢迎来到Racket!
- 统计学第一类错误和第二类错误
- RHCE linux学习第一天
- 数智化转型中的零售餐饮行业
- 我们公司使用了 5 年的系统限流方案 ,从实现到部署实战详解,稳的一B
- 《英语口语900句 (624页+360分钟录音)》(Oral English 900 Expressions)
- 【其他】逻辑、逻辑推理概念
热门文章
- Java词法分析器的设计与实现
- 游戏直播用哪个录屏软件好?
- 一笔画: 表现绘画过程的美
- 通达信板块监控指标_板块监控及使用方法指标详解 通达信板块监控
- 用 VC2012 产生脱离VC运行库的 C/C++ 程序
- 银行对公账户编码规则
- arduino atmega328P MCP4725 proteus 仿真 程序
- 用puttygen工具把私钥id_rsa转换成公钥id_rsa.ppk
- 基于SSM的宠物领养系统(附源码)
- [nssl 1322][jzoj cz 2109] 清兵线 {dp}