1. 基础
    TensorFlow 基础
    TensorFlow 模型建立与训练
    基础示例:多层感知机(MLP)
    卷积神经网络(CNN)
    循环神经网络(RNN)
    深度强化学习(DRL)
    Keras Pipeline
    自定义层、损失函数和评估指标
    常用模块 tf.train.Checkpoint :变量的保存与恢复
    常用模块 TensorBoard:训练过程可视化
    常用模块 tf.data :数据集的构建与预处理
    常用模块 TFRecord :TensorFlow 数据集存储格式
    常用模块 tf.function :图执行模式
    常用模块 tf.TensorArray :TensorFlow 动态数组
    常用模块 tf.config:GPU 的使用与分配

  2. 部署
    TensorFlow 模型导出
    TensorFlow Serving
    TensorFlow Lite

  3. 大规模训练与加速
    TensorFlow 分布式训练
    使用 TPU 训练 TensorFlow 模型

  4. 扩展
    TensorFlow Hub 模型复用
    TensorFlow Datasets 数据集载入

  5. 附录
    强化学习基础简介


目录

强化学习 (Reinforcement learning,RL)强调如何基于环境而行动,以取得最大化的预期利益。结合了深度学习技术后的强化学习(Deep Reinforcement learning,DRL)更是如虎添翼。近年广为人知的 AlphaGo 即是深度强化学习的典型应用。

可参考强化学习基础以获得强化学习的基础知识。

这里,我们使用深度强化学习玩 CartPole(倒立摆)游戏。倒立摆是控制论中的经典问题,在这个游戏中,一根杆的底部与一个小车通过轴相连,而杆的重心在轴之上,因此是一个不稳定的系统。在重力的作用下,杆很容易倒下。而我们则需要控制小车在水平的轨道上进行左右运动,以使得杆一直保持竖直平衡状态。

CartPole 游戏

我们使用 OpenAI 推出的 Gym 环境库 中的 CartPole 游戏环境,可使用 pip install gym 进行安装,具体安装步骤和教程可参考 官方文档 和 这里 。和 Gym 的交互过程很像是一个回合制游戏,我们首先获得游戏的初始状态(比如杆的初始角度和小车位置),然后在每个回合 t,我们都需要在当前可行的动作中选择一个并交由 Gym 执行(比如向左或者向右推动小车,每个回合中二者只能择一),Gym 在执行动作后,会返回动作执行后的下一个状态当前回合所获得的奖励值(比如我们选择向左推动小车并执行后,小车位置更加偏左,而杆的角度更加偏右,Gym 将新的角度和位置返回给我们。而如果杆在这一回合仍没有倒下,Gym 同时返回给我们一个小的正奖励)。这个过程可以一直迭代下去,直到游戏终止(比如杆倒下了)。在 Python 中,Gym 的基本调用方法如下:

import gymenv = gym.make('CartPole-v1')       # 实例化一个游戏环境,参数为游戏名称
state = env.reset()                 # 初始化环境,获得初始状态
while True:env.render()                    # 对当前帧进行渲染,绘图到屏幕action = model.predict(state)   # 假设我们有一个训练好的模型,能够通过当前状态预测出这时应该进行的动作next_state, reward, done, info = env.step(action)   # 让环境执行动作,获得执行完动作的下一个状态,动作的奖励,游戏是否已结束以及额外信息if done:                        # 如果游戏结束则退出循环break

那么,我们的任务就是训练出一个模型,能够根据当前的状态预测出应该进行的一个好的动作。粗略地说,一个好的动作应当能够最大化整个游戏过程中获得的奖励之和,这也是强化学习的目标。以 CartPole 游戏为例,我们的目标是希望做出合适的动作使得杆一直不倒,即游戏交互的回合数尽可能地多。而回合每进行一次,我们都会获得一个小的正奖励,回合数越多则累积的奖励值也越高。因此,我们最大化游戏过程中的奖励之和与我们的最终目标是一致的。

以下代码展示了如何使用深度强化学习中的 Deep Q-Learning 方法 [Mnih2013] 来训练模型。首先,我们引入 TensorFlow、Gym 和一些常用库,并定义一些模型超参数:

import tensorflow as tf
import numpy as np
import gym
import random
from collections import dequenum_episodes = 500              # 游戏训练的总episode数量
num_exploration_episodes = 100  # 探索过程所占的episode数量
max_len_episode = 1000          # 每个episode的最大回合数
batch_size = 32                 # 批次大小
learning_rate = 1e-3            # 学习率
gamma = 1.                      # 折扣因子
initial_epsilon = 1.            # 探索起始时的探索率
final_epsilon = 0.01            # 探索终止时的探索率

然后,我们使用 tf.keras.Model 建立一个 Q 函数网络(Q-network),用于拟合 Q Learning 中的 Q 函数。这里我们使用较简单的多层全连接神经网络进行拟合。该网络输入当前状态,输出各个动作下的 Q-value(CartPole 下为 2 维,即向左和向右推动小车)。

class QNetwork(tf.keras.Model):def __init__(self):super().__init__()self.dense1 = tf.keras.layers.Dense(units=24, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(units=24, activation=tf.nn.relu)self.dense3 = tf.keras.layers.Dense(units=2)def call(self, inputs):x = self.dense1(inputs)x = self.dense2(x)x = self.dense3(x)return xdef predict(self, inputs):q_values = self(inputs)return tf.argmax(q_values, axis=-1)

最后,我们在主程序中实现 Q Learning 算法。

if __name__ == '__main__':env = gym.make('CartPole-v1')       # 实例化一个游戏环境,参数为游戏名称model = QNetwork()optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)replay_buffer = deque(maxlen=10000) # 使用一个 deque 作为 Q Learning 的经验回放池epsilon = initial_epsilonfor episode_id in range(num_episodes):state = env.reset()             # 初始化环境,获得初始状态epsilon = max(                  # 计算当前探索率initial_epsilon * (num_exploration_episodes - episode_id) / num_exploration_episodes,final_epsilon)for t in range(max_len_episode):env.render()                                # 对当前帧进行渲染,绘图到屏幕if random.random() < epsilon:               # epsilon-greedy 探索策略,以 epsilon 的概率选择随机动作action = env.action_space.sample()      # 选择随机动作(探索)else:action = model.predict(np.expand_dims(state, axis=0)).numpy()   # 选择模型计算出的 Q Value 最大的动作action = action[0]# 让环境执行动作,获得执行完动作的下一个状态,动作的奖励,游戏是否已结束以及额外信息next_state, reward, done, info = env.step(action)# 如果游戏Game Over,给予大的负奖励reward = -10. if done else reward# 将(state, action, reward, next_state)的四元组(外加 done 标签表示是否结束)放入经验回放池replay_buffer.append((state, action, reward, next_state, 1 if done else 0))# 更新当前 statestate = next_stateif done:                                    # 游戏结束则退出本轮循环,进行下一个 episodeprint("episode %d, epsilon %f, score %d" % (episode_id, epsilon, t))breakif len(replay_buffer) >= batch_size:# 从经验回放池中随机取一个批次的四元组,并分别转换为 NumPy 数组batch_state, batch_action, batch_reward, batch_next_state, batch_done = zip(*random.sample(replay_buffer, batch_size))batch_state, batch_reward, batch_next_state, batch_done = \[np.array(a, dtype=np.float32) for a in [batch_state, batch_reward, batch_next_state, batch_done]]batch_action = np.array(batch_action, dtype=np.int32)q_value = model(batch_next_state)y = batch_reward + (gamma * tf.reduce_max(q_value, axis=1)) * (1 - batch_done)  # 计算 y 值with tf.GradientTape() as tape:loss = tf.keras.losses.mean_squared_error(  # 最小化 y 和 Q-value 的距离y_true=y,y_pred=tf.reduce_sum(model(batch_state) * tf.one_hot(batch_action, depth=2), axis=1))grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))       # 计算梯度并更新参数

对于不同的任务(或者说环境),我们需要根据任务的特点,设计不同的状态以及采取合适的网络来拟合 Q 函数。例如,如果我们考虑经典的打砖块游戏(Gym 环境库中的 Breakout-v0 ),每一次执行动作(挡板向左、向右或不动),都会返回一个 210 * 160 * 3 的 RGB 图片,表示当前屏幕画面。为了给打砖块游戏这个任务设计合适的状态表示,我们有以下分析:

  • 砖块的颜色信息并不是很重要,画面转换成灰度也不影响操作,因此可以去除状态中的颜色信息(即将图片转为灰度表示);
  • 小球移动的信息很重要,如果只知道单帧画面而不知道小球往哪边运动,即使是人也很难判断挡板应当移动的方向。因此,必须在状态中加入表征小球运动方向的信息。一个简单的方式是将当前帧与前面几帧的画面进行叠加,得到一个 210 * 160 * XX 为叠加帧数)的状态表示;
  • 每帧的分辨率不需要特别高,只要能大致表征方块、小球和挡板的位置以做出决策即可,因此对于每帧的长宽可做适当压缩。

而考虑到我们需要从图像信息中提取特征,使用 CNN 作为拟合 Q 函数的网络将更为适合。由此,将上面的 QNetwork 更换为 CNN 网络,并对状态做一些修改,即可用于玩一些简单的视频游戏。

【Tensorflow教程笔记】深度强化学习(DRL)相关推荐

  1. TensorFlow 2.0深度强化学习指南

    在本教程中,我将通过实施Advantage Actor-Critic(演员-评论家,A2C)代理来解决经典的CartPole-v0环境,通过深度强化学习(DRL)展示即将推出的TensorFlow2. ...

  2. 深度强化学习DRL训练指南和现存问题(D3QN(Dueling Double DQN))

    目录 参数 iteration episode epoch Batch_Size Experimence Replay Buffer经验回放缓存 Reward discount factor或gamm ...

  3. AI量化(代码):深度强化学习DRL应用于金融量化

    原创文章第93篇,专注"个人成长与财富自由.世界运作的逻辑, AI量化投资". 今天要说说强化学习. 强化学习个人认为,是最契合金融投资的范式.它其实不是一个具体的算法,而是一种范 ...

  4. 【入门教程】TensorFlow 2 模型:深度强化学习

    文 /  李锡涵,Google Developers Expert 本文节选自<简单粗暴 TensorFlow 2> 本文将介绍在 OpenAI 的 gym 环境下,使用 TensorFl ...

  5. 论文研读笔记(二)——通过深度强化学习避免碰撞的编队控制

    通过深度强化学习避免碰撞的编队控制(Formation Control with Collision Avoidance through Deep Reinforcement Learning) 文献 ...

  6. bootstrapt学习指南_TensorFlow 2.0深度强化学习指南

    摘要: 用深度强化学习来展示TensorFlow 2.0的强大特性! 在本教程中,我将通过实施Advantage Actor-Critic(演员-评论家,A2C)代理来解决经典的CartPole-v0 ...

  7. 基于深度强化学习的室内场景目标驱动视觉导航

    基于深度强化学习的室内场景目标驱动视觉导航 摘要 介绍 相关工作 AI2-THOR框架 目标驱动导航模型 A.问题陈述 B.公式问题 C.学习设置 D.模型 E.训练协议 F.网络架构 实验 A.导航 ...

  8. 论文研读——基于深度强化学习的自动驾驶汽车运动规划研究综述

    论文研读--Survey of Deep Reinforcement Learning for Motion Planning of Autonomous V ehicles 此篇文章为论文的学习笔记 ...

  9. 《强化学习周刊》第26期:UCL UC Berkeley发表深度强化学习中的泛化研究综述、JHU推出基于强化学习的人工决策模型...

    No.26 智源社区 强化学习组 强 化 学  习 研究 观点 资源 活动 关于周刊 强化学习作为人工智能领域研究热点之一,其研究进展与成果也引发了众多关注.为帮助研究与工程人员了解该领域的相关进展和 ...

  10. 深度强化学习探索算法最新综述,近200篇文献揭示挑战和未来方向

    ©作者 | 杨天培.汤宏垚等 来源 | 机器之心 强化学习是在与环境交互过程中不断学习的,⽽交互中获得的数据质量很⼤程度上决定了智能体能够学习到的策略的⽔平.因此,如何引导智能体探索成为强化学习领域研 ...

最新文章

  1. UITableVeiw相关的需求解决
  2. 类成员指针和0x0地址转换
  3. boost::graph::isomorphism用法的测试程序
  4. java 加密解密简单实现
  5. oracle简易数据库搭建,Oracle 10g 手工创建一个最简单的数据库
  6. 【Express】—get传递参数
  7. android 实现抽屉效果
  8. php pos 接收,PHP开发中php pos()函数的使用详解
  9. UVA10014 Simple calculations【数列】
  10. VDN元宇宙游戏公会|Cool Metaverse首个开放共享式元宇宙平台
  11. java计算机毕业设计家教管理系统源码+mysql数据库+系统+lw文档+部署
  12. ENVI学习总结(十)——遥感图像监督分类
  13. DepthMap(1):D. Eigen (NIPS2014)
  14. 分享,请不要忽视了作者的版权
  15. Expressive JavaScript
  16. 有没有视频合并软件?合并视频这样做
  17. From Big to Small
  18. Cris 的Python日记(五):Python 数据结构之元祖,字典和集合
  19. 了解交换机基本原理与配置
  20. Semi-Supervised Deep Learning for Monocular Depth Map Prediction

热门文章

  1. MyBatis四大对象
  2. 这几款真香旗舰机,买到就是赚到,有你入手了的吗?
  3. jit和jitx区别_JIT是什么东西 分分钟打下来!
  4. 【红帽入门指南】第二期:Linux的基本使用
  5. 有奖问卷 | 2022年中国云原生安全调查,邀您来答!
  6. 用手机如何把PDF转成PPT文件
  7. 1.tessent命令学习笔记
  8. 头歌-信息安全技术-【实训10】HTML信息隐藏、动态分析技术
  9. 物联网实时内核 vnRTOS 免费开源
  10. PAMI19 - 强大的级联RCNN架构《Cascade R-CNN: High Quality Object Detection and Instance Segmentation》