TensorFlow2.0 Alpha版已经发布,在2.0中最重要的API或者说到处都出现的API是谁,那无疑是Keras。因此用过2.0的人都会吐槽全世界都是Keras。今天我们就来说说Keras这个高级API。

作者 | 汤兴旺

编辑 | 汤兴旺

1 Keras概述

在TensorFlow2.0中,Keras是一个用于构建和训练深度学习模型的高阶 API。因此如果你正在使用TensorFow2.0,那么使用Keras构建深度学习模型是您的不二选择。在Keras API中总共有如下三大块:

在Modules中有构建训练模型各种必备的组件,如激活函数activations、损失函数losses、优化器optimizers等;在Class中有Sequential和Model两个类,它们用来堆叠模型;在Functions中有Input()函数,它用来实例化张量。

因此若您使用的深度学习框架是TensorFlow,而且是2.0版本,那么你就不可能不使用tensorflow.keras。这也就是使用过TensorFlow2.0版本的都在吐槽全世界都是Keras的原因。

2 Modules

通过上面的介绍,我们知道在Modules中有activations、losses、optimizers等构建训练模型时各种必备的组件。下图就是Modules中有所的模块。

下面我们详细说说里面最常见的几个模块应该如何使用。

1. 常用的数据集(datasets)


在TensorFlow2.0中,常用的数据集需要使用tf.keras.datasets来加载,在datasets中有如下数据集。

对于上图中的数据集我们可以像下面这样加载

(train_images,train_labels),(test_images,test_labels)= keras.datasets.fashion_mnist.load_data()

当然我们平时使用的数据集肯定不在于此,这些数据集都是些最基础的数据集。对于自己的数据如何读取,请期待我们下次的分享。

2. 神经网络层(Layers)


在构建深度学习网络模型时,我们需要定制各种各样的层结构。这时候就要用到layers了,下图是TensorFlow2.0中部分层,它们都是Layer的子类。

那么我们如何使用layer来构建模型呢?方法如下:

from tensorflow.keras import layers

layers.Conv2D()

layers.MaxPool2D()

layers.Flatten()

layers.Dense()

3. 激活函数(Optimizers)


在构建深度学习网络时,我们经常需要选择激活函数来使网络的表达能力更强。下面将介绍TensorFlow2.0中的激活函数及它们应该在TensorFlow2.0中该如何使用。下图是TensorFlow2.0中部分激活函数:

from tensorflow.keras import layers

layers.Conv2D(...,activation='relu')

layers.Dense(...,activation='softmax')

4. 优化器(activations)


通常当我们准备好数据,设计好模型后,我们就需要选择一个合适的优化器(Optimizers)对模型进行优化。下面将介绍TensorFlow2.0中的优化器及他们应该在TensorFlow2.0中该如何使用。下图是TensorFlow2.0中所有的优化器,它们都是Optimizer的子类。

对于优化器的使用你可以像下面这样使用:

optimizers = tf.keras.optimizers.Adam()

optimizers = tf.keras.optimizers.SGD()

...


5. 损失函数(Losses)


我们知道当我们设计好模型时我们需要优化模型,所谓的优化就是优化网络权值使损失函数值变小,但是损失函数变小是否能代表精度越高呢?那么多的损失函数,我们又该如何选择呢?接下来我们了解下在TensorFlow2.0中如何使用损失函数。下图是TensorFlow2.0中所有的损失函数,它们都是Loss的子类。

对于损失函数的使用你可以像下面这样使用:

loss = tf.keras.losses.SparseCategoricalCrossentropy()

loss = tf.keras.losses.mean_squared_error()

...

3 Class

在Class中有Sequential和Model两个类,它们分别是用来堆叠网络层和把堆叠好的层实例化可以训练的模型。


1. Model


对于实例化Model有下面两种方法

(1).使用keras.Model API

import tensorflow as tftf.keras.Model(inputs=inputs, outputs=outputs)

(2).继承Model类

import tensorflow as tfclass MyModel(tf.keras.Model):


2. Sequential


在TensorFlow2.0中,我们可以使用Sequential模型。具体方式如下:

model = keras.Sequential()

model = model.add(layers.Conv2D(input_shape=(x_train.shape[1], x_train.shape[2],x_train.shape[3]),filters=32,kernel_size=(3,3), strides=(1,1), padding='valid',activation='relu'))

model.add(layers.MaxPool2D(pool_size=(2,2)))

model.add(layers.Flatten())model.add(layers.Dense(32,activation='relu'))

model.add(layers.Dense(10, activation='softmax'))

model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy'])

4 Functions

在Functions中,有一个Input函数,其用来实例化Keras张量。对于Input函数,它有如下参数

tf.keras.Input(

具体方法如下:

x = Input(shape=(32,))

5 简单的图像分类模型实例

#1导入相应的API

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import layers

#2加载数据

(train_images,train_labels),(test_images,test_labels)= keras.datasets.fashion_mnist.load_data()

#3构建网络

model = keras.Sequential()

model = model.add(layers.Conv2D(input_shape=(x_train.shape[1], x_train.shape[2],x_train.shape[3]),filters=32,kernel_size=(3,3), strides=(1,1), padding='valid',activation='relu'))

model.add(layers.MaxPool2D(pool_size=(2,2)))

model.add(layers.Flatten())model.add(layers.Dense(32,activation='relu'))

model.add(layers.Dense(10, activation='softmax'))

model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy'])

#4模型显示

model.summary()

#5模型训练

model_train=model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)

总结

在本讲中,我们简单的了解了TensorFlow2.0中高级API Keras是如何使用的,我们可以看到Keras真的是无处不在,如果你想学好TensorFlow2.0,那么你必须掌握好Kears。

下期预告:如何读取自己的数据集及数据的使用。

最近直播

今日看图猜技术

网络结构

更多精彩内容请关注知乎专栏《有三AI学院》

转载文章请后台联系

侵权必究

往期精选

  • 【TensorFlow2.0】TensorFlow2.0专栏上线,你来吗?

  • 【AI初识境】从3次人工智能潮起潮落说起

  • 【AI初识境】从头理解神经网络-内行与外行的分水岭

  • 【AI初识境】近20年深度学习在图像领域的重要进展节点

  • 【AI初识境】激活函数:从人工设计到自动搜索

  • 【AI初识境】什么是深度学习成功的开始?参数初始化

  • 【AI初识境】深度学习模型中的Normalization,你懂了多少?

  • 【AI初识境】为了围剿SGD大家这些年想过的那十几招

  • 【AI初识境】被Hinton,DeepMind和斯坦福嫌弃的池化,到底是什么?

  • 【AI初识境】如何增加深度学习模型的泛化能力

  • 【AI初识境】深度学习模型评估,从图像分类到生成模型

  • 【AI初识境】深度学习中常用的损失函数有哪些?

  • 【AI初识境】给深度学习新手做项目的10个建议

  • 【AI不惑境】数据压榨有多狠,人工智能就有多成功

  • 【AI不惑境】网络深度对深度学习模型性能有什么影响?

  • 【AI不惑境】网络的宽度如何影响深度学习模型的性能?

  • 【AI不惑境】学习率和batchsize如何影响模型的性能?

  • 【完结】深度学习CV算法工程师从入门到初级面试有多远,大概是25篇文章的距离

  • 【完结】优秀的深度学习从业者都有哪些优秀的习惯

  • 【完结】给新手的12大深度学习开源框架快速入门项目

  • 【完结】总结12大CNN主流模型架构设计思想

  • 创业第一天,有三AI扔出了深度学习的150多篇文章和10多个专栏

  • 言有三新书预售,不贵,有料

  • 这个春天,有三最后一月的学习“季划”招生

  • 有三AI VIP会员发售,你的私人AI顾问已上线

  • 有三AI知识星球官宣,BAT等大咖等你来撩

  • 有三AI小程序上线,把你的代码show给世界

  • 揭秘7大AI学习板块,这个星球推荐你拥有

【TensorFlow2.0】以后我们再也离不开Keras了?相关推荐

  1. tensorflow2.0——预测泰坦尼克号旅客生存概率(Keras应用实践)

    一.数据准备 1.导入相关的库 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import pa ...

  2. internetreadfile读取数据长度为0_【完结】TensorFlow2.0 快速上手手册

    大家好,这是专栏<TensorFlow2.0>的第五篇文章,我们对专栏<TensorFlow2.0>进行一个总结. 我们知道全新的TensorFlow2.0 Alpha已经于2 ...

  3. 【TensorFlow2.0】如何搭建网络模型?

    大家好,这是专栏<TensorFlow2.0>的第四篇文章,讲述网络模型的搭建. 我们知道在不考虑输入层的情况下,一个典型的卷积神经网络通常由若干个卷积层.激活层.池化层及全连接层组成,无 ...

  4. tensorflow2.0基础简介

    tensorflow2.0简介 1.tensorflow 2.0基础知识简介 tensorflow2.0是谷歌在2019年3月份发布更新的一款到端开源机器学习平台,其目的在于优化tensorflow1 ...

  5. 【TensorFlow2.0】数据读取与使用方式

    大家好,这是专栏<TensorFlow2.0>的第三篇文章,讲述如何使用TensorFlow2.0读取和使用自己的数据集. 如果您正在学习计算机视觉,无论你通过书籍还是视频学习,大部分的教 ...

  6. pip更新失败_最全Tensorflow2.0 入门教程持续更新

    最全Tensorflow 2.0 入门教程持续更新: Doit:最全Tensorflow 2.0 入门教程持续更新​zhuanlan.zhihu.com 完整tensorflow2.0教程代码请看ht ...

  7. TensorFlow2.0学习笔记2-tf2.0两种方式搭建神经网络

    目录 一,TensorFlow2.0搭建神经网络八股 1)import  [引入相关模块] 2)train,test  [告知喂入网络的训练集测试集以及相应的标签] 3)model=tf.keras. ...

  8. tensorflow2.0教程- Keras 快速入门

    tensorflow2.0教程-tensorflow.keras 快速入门 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/artic ...

  9. 【TensorFlow2.0】(6) 数据统计,范数、最值、求和、均值、最值位置、唯一值、张量比较

    各位同学好,今天和大家分享一下TensorFlow2.0中的数据分析操作.内容有: (1)范数 tf.norm():(2)最值 tf.reduce_min(), tf.reduce_max()(3)求 ...

最新文章

  1. klee错误汇报二:KLEE的optimize选项的一个困惑
  2. PTA团体程序设计天梯赛篇(四)----几何+算法专题
  3. RabbitMQ系列笔记work模式
  4. java在线支付---13.java在线支付所有源码:
  5. 收下这份说明书,原来迈进智能计算的大门如此简单
  6. java删不了_java – 为什么我不能删除项目?
  7. python读取npy文件
  8. MQ消息队列常用命令
  9. 电脑翻页时钟屏保Fliqlo
  10. html登录qq页面无法显示,如何解决QQ可以登录网页而无法打开的问题
  11. 信息学奥赛一本通 1183:病人排队 | OpenJudge NOI 1.10 08:病人排队
  12. UOS 加锁文件夹/文件之解锁
  13. PDF里面复制出来的文章,在word里去掉回车符
  14. java lambda 反射_反射调用与Lambda表达式调用
  15. 小船过河(贪心算法)
  16. 【思考一】Android程序员想做手机游戏开发
  17. 推荐几款文字翻译软件,快速实现翻译
  18. 字节跳动面试:Android-系统预设-App,有什么难的?
  19. 车牌、手机、身份证、等敏感信息 屏蔽 替换 、中文转unicode编码 函数
  20. enq:TM-contention

热门文章

  1. java基础之集合类
  2. 观察者模式(Observer) 简介
  3. abb限位开关已打开drv1_施工升降机上有10个限位器,你都知道了吗?
  4. 小学文凭有计算机知识,重大版小学信息技术毕业复习题
  5. 计算机图形学多边形填充代码_零基础学计算机图形学太难?或许你缺的只是一本好书...
  6. mysql 排名_微服务架构下,如何利用Mysql的limit配合orderby进行排名统计
  7. 您有一份阿里云云原生直播攻略待查收
  8. 课时 11:可观测性:你的应用健康吗?(莫源)
  9. 阿里巴巴如何改善开发人员在 K8s 上的体验?
  10. Knative 基本功能深入剖析:Knative Serving 的流量灰度和版本管理