出租车调度-Q learning & SARSA

  • 案例分析
  • 实验环境使用
  • 同策时序差分学习调度
  • 异策时序差分调度
  • 资格迹学习调度
  • 结论

代码链接

案例分析

本节考虑Gym库里出租车调度问题(Taxi-v2):在一个5×5方格表示的地图上,有4个出租车停靠点。在每个回合开始时,有一个乘客会随机出现在4个出租车停靠点中的一个,并想在任意一个出租车停靠点下车。出租车会随机出现在25个位置的任意一个位置。出租车需要通过移动自己的位置,到达乘客所在的位置,并将乘客接上车,然后移动到乘客想下车的位置,再让乘客下车。出租车只能在地图范围内上下左右移动一格,并且在有竖线阻拦地方不能横向移动。出租车完成一次任务可以得到20个奖励,每次试图移动得到-1个奖励,不合理地邀请乘客上车(例如目前车和乘客不在同一位置,或乘客已经上车)或让乘客下车(例如车不在目的地,或车上没有乘客)得到-10个奖励。希望调度出租车让总奖励的期望最大。

实验环境使用

Gym库的Taxi-v2环境实现了出租车调度问题的环境。导入环境后,可以用env.reset()来初始化环境,用env.step()来执行一步,用env.render()来显示当前局势。env.render()会打印出的局势图,其中乘客的位置、目的地会用彩色字母显示,出租车的位置会高亮显示。具体而言,如果乘客不在车上,乘客等待地点(位置)的字母会显示为蓝色。目的地所在的字母会显示为洋红色。如果乘客不在车上,出租车所在的位置会用黄色高亮;如果乘客在车上,出租车所在的位置会用绿色高亮。

这个环境中的观测是一个范围为[0,500)的int型数值。这个数值实际上唯一表示了整个环境的状态。我们可以用env.decode()函数将这个int数值转化为长度为4的元组(taxirow,taxicol,passloc,desti dx),其各元素含义如下:
·taxirow和taxicol是取值为{0,1,2,3,4}的int型变量,表示当前出租车的位置;
·passloc是取值为{0,1,2,3,4}的int型数值,表示乘客的位置,其中0~3表示乘客在表1中对应的位置等待,4表示乘客在车上;
·destidx是取值为{0,1,2,3}的int型数值,表示目的地,目的地的位置由表5-1给出。全部的状态总数为(5×5)×5×4=500。

这个问题中的动作是取自{0,1,2,3,4,5}的int型数值,其含义下表所示。表中还给出了对应的env.render()函数给出的文字提示以及执行动作后可能得到的奖励值。

代码清单给出了初始化环境并玩一步的代码。初始化后,借助env.decode()获得了出租车、乘客和目的地的位置,并将地图显示出来,接着试图玩了一步。

import gym
env = gym.make('Taxi-v2')
state = env.reset()
taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)
print(taxirow, taxicol, passloc, destidx)
print('出租车位置 = {}'.format((taxirow, taxicol)))
print('乘客位置 = {}'.format(env.unwrapped.locs[passloc]))
print('目标位置 = {}'.format(env.unwrapped.locs[destidx]))
env.render()
env.step(1)

至此,我们已经会使用这个环境了。

同策时序差分学习调度

本节我们使用SARSA算法和期望SARSA算法来学习策略。
首先我们来看SARSA算法。以下代码中的SARSAAgent类play_sarsa()函数共同实现了SARSA算法。其中,SARSAAgent类包括了智能体的学习逻辑和判决逻辑,是智能体类;play_sarsa()函数实现了智能体和环境交互的逻辑。play_sarsa()函数有两个bool类型的参数,参数train表示是否对智能体进行训练,参数render表示是否用对人类友好的方式显示当前环境。这里把SARSA算法拆分成一个智能体类和一个描述智能体和环境交互的函数,是为了能够更加清晰地将智能体的学习和决策过程隔离开来。智能体和环境的交互过程可以为许多类似的智能体重复使用。例如,play_sarsa()函数不仅在SARSA算法中被使用,还会被本章后续的SARSA(λ)算法使用,甚至被后续章节使用。

class SARSAAgent:def __init__(self, env, gamma=0.9, learning_rate=0.2, epsilon=.01):self.gamma = gammaself.learning_rate = learning_rateself.epsilon = epsilonself.action_n = env.action_space.nself.q = np.zeros((env.observation_space.n, env.action_space.n))def decide(self, state):if np.random.uniform() > self.epsilon:action = self.q[state].argmax()else:action = np.random.randint(self.action_n)return actiondef learn(self, state, action, reward, next_state, done, next_action):u = reward + self.gamma * \self.q[next_state, next_action] * (1. - done)td_error = u - self.q[state, action]self.q[state, action] += self.learning_rate * td_error

SARSA智能体与环境交互一回合

def play_sarsa(env, agent, train=False, render=False):episode_reward = 0observation = env.reset()action = agent.decide(observation)while True:if render:env.render()next_observation, reward, done, _ = env.step(action)episode_reward += rewardnext_action = agent.decide(next_observation) # 终止状态时此步无意义if train:agent.learn(observation, action, reward, next_observation,done, next_action)if done:breakobservation, action = next_observation, next_actionreturn episode_reward

智能体在初始化时,先根据状态空间和动作空间的大小初始化q(s,a),s∈S,a∈Aq(s,a),s∈\mathcal{S},a∈\mathcal{A}q(s,a),s∈S,a∈A。在判决时,使用了ε贪心策略。

下面给出了训练SARSA算法的代码。该代码调用play_sarsa()函数5000次,运行了5000回合的环境进行训练。

# 训练
episodes = 3000
episode_rewards = []
for episode in range(episodes):episode_reward = play_sarsa(env, agent, train=True)episode_rewards.append(episode_reward)plt.plot(episode_rewards)# 测试
agent.epsilon = 0. # 取消探索episode_rewards = [play_sarsa(env, agent) for _ in range(100)]
print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),len(episode_rewards), np.mean(episode_rewards)))

测试结果平均总奖励数值一般在6~8.5之间。增加迭代次数往往能进一步提高性能。

agent.epsilon = 0. # 取消探索
episode_rewards = [play_sarsa(env, agent) for _ in range(100)]
print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),
len(episode_rewards), np.mean(episode_rewards)))

如果我们要显示最优价值估计,可以使用以下语句:

pd.DataFrame(agent.q)

如果显示最优策略估计,可以使用以下语句:

policy = np.eye(agent.action_n)[agent.q.argmax(axis=-1)]
pd.DataFrame(policy)

接下来使用期望SARSA算法求解最优策略。ExpectedSARSAAgent类实现了期望SARSA智能体类。

class ExpectedSARSAAgent:def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):self.gamma = gammaself.learning_rate = learning_rateself.epsilon = epsilonself.q = np.zeros((env.observation_space.n, env.action_space.n))self.action_n = env.action_space.ndef decide(self, state):if np.random.uniform() > self.epsilon:action = self.q[state].argmax()else:action = np.random.randint(self.action_n)return actiondef learn(self, state, action, reward, next_state, done):v = (self.q[next_state].mean() * self.epsilon + \self.q[next_state].max() * (1. - self.epsilon))u = reward + self.gamma * v * (1. - done)td_error = u - self.q[state, action]self.q[state, action] += self.learning_rate * td_error

play_qlearning()函数实现了期望SARSA智能体与环境的交互。这里的交互函数命名为play_qlearning,是因为期望SARSA智能体的交互函数和后续Q学习的交互函数相同。

def play_qlearning(env, agent, train=False, render=False):episode_reward = 0observation = env.reset()while True:if render:env.render()action = agent.decide(observation)next_observation, reward, done, _ = env.step(action)episode_reward += rewardif train:agent.learn(observation, action, reward, next_observation,done)if done:breakobservation = next_observationreturn episode_reward

实现了期望SARSA算法后,下面是训练和测试期望SARSA算法的代码。期望SARSA算法在这个问题中的性能往往比SARSA算法要好一些。

episodes = 5000
episode_rewards = []
for episode in range(episodes):
episode_reward = play_qlearning(env, agent, train=True)
episode_rewards.append(episode_reward)
plt.plot(episode_rewards);agent.epsilon = 0. # 取消探索
episode_rewards = [play_qlearning(env, agent) for _ in range(100)]
print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),
len(episode_rewards), np.mean(episode_rewards)))

异策时序差分调度

本节我们使用Q学习和双重Q学习来学习最优策略。
首先来看Q学习算法。QLearningAgent智能体类和play_qlearning()函数一起实现了Q学习算法。QLearningAgent类和ExpectedSARSAAgent类的区别在于learn()函数内自益的方法不同。

class QLearningAgent:def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):self.gamma = gammaself.learning_rate = learning_rateself.epsilon = epsilonself.action_n = env.action_space.nself.q = np.zeros((env.observation_space.n, env.action_space.n))def decide(self, state):if np.random.uniform() > self.epsilon:action = self.q[state].argmax()else:action = np.random.randint(self.action_n)return actiondef learn(self, state, action, reward, next_state, done):u = reward + self.gamma * self.q[next_state].max() * (1. - done)td_error = u - self.q[state, action]self.q[state, action] += self.learning_rate * td_error

接下来看双重Q学习算法。DoubleQLearningAgent类和play_qlearning()函数一起实现了双重Q学习算法。双重Q学习涉及两组动作价值估计,DoubleQLearnignAgent类和QLearningAgent类在构造函数、decide()函数和learn()函数都有区别。在该问题中,最大化偏差并不明显,所以双重Q学习往往不能得到好处。

class DoubleQLearningAgent:def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):self.gamma = gammaself.learning_rate = learning_rateself.epsilon = epsilonself.action_n = env.action_space.nself.q0 = np.zeros((env.observation_space.n, env.action_space.n))self.q1 = np.zeros((env.observation_space.n, env.action_space.n))def decide(self, state):if np.random.uniform() > self.epsilon:action = (self.q0 + self.q1)[state].argmax()else:action = np.random.randint(self.action_n)return actiondef learn(self, state, action, reward, next_state, done):if np.random.randint(2):self.q0, self.q1 = self.q1, self.q0a = self.q0[next_state].argmax()u = reward + self.gamma * self.q1[next_state, a] * (1. - done)td_error = u - self.q0[state, action]self.q0[state, action] += self.learning_rate * td_error

资格迹学习调度

本节使用SARSA(λ)算法来学习策略。代码实现了SARSA(λ)算法智能体类SARSALambdaAgent类,它由代码清单5-2中的SARSAAgent类派生而来。与SARSAAgent类相比,它多了需要控制衰减速度的参数lambd和控制资格迹增加的参数beta。值得一提的是,lambda是Python的关键字,所以这里不用lambda作为变量名,而是用去掉最后一个字母的lambd作为变量名。由于引入了资格迹,所以SARSA(λ)算法的性能往往比单步SARSA算法要好。

class SARSALambdaAgent(SARSAAgent):def __init__(self, env, lambd=0.6, beta=1.,gamma=0.9, learning_rate=0.1, epsilon=.01):super().__init__(env, gamma=gamma, learning_rate=learning_rate,epsilon=epsilon)self.lambd = lambdself.beta = betaself.e = np.zeros((env.observation_space.n, env.action_space.n))def learn(self, state, action, reward, next_state, done, next_action):# 更新资格迹self.e *= (self.lambd * self.gamma)self.e[state, action] = 1. + self.beta * self.e[state, action]# 更新价值u = reward + self.gamma * \self.q[next_state, next_action] * (1. - done)td_error = u - self.q[state, action]self.q += self.learning_rate * self.e * td_errorif done:self.e *= 0.

在这一节中,我们尝试了很多算法,有些算法的性能相对另外一些较好。其中的原因比较复杂,可能是算法本身的问题,也可能是参数选择的问题。没有一个算法是对所有的任务都有效的。可能对于这个任务,这个算法效果好;换了一个任务后,另外一个算法效果好。

结论

无模型时序差分更新方法,包括了同策时序差分算法SARSA算法和期望SARSA算法,以及异策时序差分算法Q学习和双重Q学习算法。各种算法的主要区别在于更新目标Ut具有不同的表达式。最后还介绍了历史上具有重大影响力的资格迹算法。

[强化学习实战]出租车调度-Q learning SARSA相关推荐

  1. 基于多智能体强化学习的出租车调度框架

    网约车平台的繁荣使得人们比以往能更加"智慧"的出行.平台能实时掌握全局的车辆与乘客的供需关系,从而在车辆与乘客之间实现更加有效的匹配.但车辆与乘客还是会经常遭遇"车辆不停 ...

  2. 【强化学习实战】基于gym和tensorflow的强化学习算法实现

    [新智元导读]知乎专栏强化学习大讲堂作者郭宪博士开讲<强化学习从入门到进阶>,我们为您节选了其中的第二节<基于gym和tensorflow的强化学习算法实现>,希望对您有所帮助 ...

  3. 【强化学习实战-04】DQN和Double DQN保姆级教程(2):以MountainCar-v0

    [强化学习实战-04]DQN和Double DQN保姆级教程(2):以MountainCar-v0 实战:用Double DQN求解MountainCar问题 MountainCar问题详解 Moun ...

  4. 【经典书籍】深度强化学习实战(附最新PDF和源代码下载)

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 深度强化学习可以说是人工智能领域现在最热门的方向,吸引了众多该领域优秀的科学家去发 ...

  5. 30篇强化学习求解车间调度文章(中文)大全

    国内使用强化学习求解车间调度问题的研究起步较晚,基本是在在2000年以后,而深度强化学习求解车间调度问题更是在2019.2020年左右开始流行.今天在上一篇文章的基础上((吐血整理)118篇强化学习求 ...

  6. Keras深度学习实战——使用深度Q学习进行SpaceInvaders游戏

    Keras深度学习实战--使用深度Q学习进行SpaceInvaders游戏 0. 前言 1. 问题与模型分析 2. 使用深度 Q 学习进行 SpaceInvaders 游戏 相关链接 0. 前言 在& ...

  7. PyTorch强化学习实战(1)——强化学习环境配置与PyTorch基础

    PyTorch强化学习实战(1)--强化学习环境配置与PyTorch基础 0. 前言 1. 搭建 PyTorch 环境 2. OpenAI Gym简介与安装 3. 模拟 Atari 环境 4. 模拟 ...

  8. 强化学习 最前沿之Hierarchical reinforcement learning(一)

    强化学习-最前沿系列 深度强化学习作为当前发展最快的方向,可以说是百家争鸣的时代.针对特定问题,针对特定环境的文章也层出不穷.对于这么多的文章和方向,如果能撇一隅,往往也能够带来较多的启发. 本系列文 ...

  9. 增强学习(五)----- 时间差分学习(Q learning, Sarsa learning)

    接下来我们回顾一下动态规划算法(DP)和蒙特卡罗方法(MC)的特点,对于动态规划算法有如下特性: 需要环境模型,即状态转移概率PsaPsa 状态值函数的估计是自举的(bootstrapping),即当 ...

最新文章

  1. DeepMind的蛋白质折叠AI解决了50年来的生物学重大挑战
  2. 轻松实现远程批量拷贝文件脚本(女学生作品)
  3. 《图解HTTP》读书笔记--第7章 确保Web安全的HTTPS
  4. Linux软件安装管理 - CentOS (三) ---- 源码包管理
  5. cf1555D. Say No to Palindromes
  6. word 编辑域中的汉字_word中插入的cad对象无法双击编辑问题解决记录
  7. linux怎么查看sklearn版本,Sklearn——Sklearn的介绍与安装
  8. 将Eclipse中的工程保存到Github的操作步骤
  9. 只用一个WiFi,渗透进企业全部内网
  10. WCF中配置文件解析
  11. 250分b区计算机专硕,2021兰州大学研究生复试分数线
  12. Android wakelock机制
  13. 使用深度学习进行表检测、信息提取和构建
  14. php 身份证格式校验,年龄计算
  15. 写给20几岁的女孩、男孩
  16. 运放指标-压摆率SR
  17. 计算机网络面试知识点整理
  18. c++语言解一元二次方程,初学C++新手跪求:用C++编 解一元二次方程 并 结果用复数表示。。...
  19. 机器学习(一):什么是机器学习
  20. 用百行Python代码写一个关于德州扑克的类

热门文章

  1. 目标检测算法(YOLOv4)
  2. Java接口回调的概念和作用
  3. Impala 技术点梳理
  4. 【QT开发笔记-基础篇】| 第五章 绘图QPainter | 5.14 平移、旋转、缩放
  5. 惩罚因子(penalty term)与损失函数(loss function)
  6. Ae:常用表达式及应用(01)
  7. 离散数学题目总结归纳
  8. python步长什么意思_python – Numpy:在每个时间步长平均许多数据点
  9. 实现无限轮播广告条如此简单
  10. 安装Goland19.3