【强化学习】PPO算法求解倒立摆问题 + Pytorch代码实战
文章目录
- 一、倒立摆问题介绍
- 二、PPO算法简介
- 三、详细资料
- 四、Python代码实战
- 4.1 运行前配置
- 4.2 主要代码
- 4.3 运行结果展示
- 4.4 关于可视化的设置
一、倒立摆问题介绍
Agent 必须在两个动作之间做出决定 - 向左或向右移动推车 - 以使连接到它的杆保持直立。
二、PPO算法简介
近端策略优化 ( proximal policy optimization, PPO):
避免在使用重要性采样时由于在 θ\thetaθ 下的 pθ(at∣st)p_\theta\left(a_t \mid s_t\right)pθ(at∣st) 与在 θ′\theta^{\prime}θ′ 下的 pθ′(at∣st)p_{\theta^{\prime}}\left(a_t \mid s_t\right)pθ′(at∣st) 相差太多, 导致重要性采样结果偏差较大而采取的算法。具体来说就是在训练的过 程中增加一个限制, 这个限制对应 θ\thetaθ 和 θ′\theta^{\prime}θ′ 输出的动作的 KL 散度, 来衡量 θ\thetaθ 与 θ′\theta^{\prime}θ′ 的相似程度。
三、详细资料
关于更加详细的PPO算法介绍,请看我之前发的博客:【EasyRL学习笔记】第五章 Proximal Policy Optimization 近端策略优化算法
在学习PPO算法前你最好能了解以下知识点:
- 全连接神经网络
- 神经网络求解分类问题
- 神经网络基本工作原理
- KL散度
四、Python代码实战
4.1 运行前配置
准备好一个RL_Utils.py文件,文件内容可以从我的一篇里博客获取:【RL工具类】强化学习常用函数工具类(Python代码)
这一步很重要,后面需要引入该RL_Utils.py文件
4.2 主要代码
import argparse
import datetime
import time
import torch.optim as optim
from torch.distributions.categorical import Categorical
import gym
from torch import nn# 这里需要改成自己的RL_Utils.py文件的路径
from Python.ReinforcementLearning.EasyRL.RL_Utils import *class PPOMemory:def __init__(self, batch_size):self.states = []self.probs = []self.vals = []self.actions = []self.rewards = []self.dones = []self.batch_size = batch_sizedef sample(self):batch_step = np.arange(0, len(self.states), self.batch_size)indices = np.arange(len(self.states), dtype=np.int64)np.random.shuffle(indices)batches = [indices[i:i + self.batch_size] for i in batch_step]return np.array(self.states), np.array(self.actions), np.array(self.probs), \np.array(self.vals), np.array(self.rewards), np.array(self.dones), batchesdef push(self, state, action, probs, vals, reward, done):self.states.append(state)self.actions.append(action)self.probs.append(probs)self.vals.append(vals)self.rewards.append(reward)self.dones.append(done)def clear(self):self.states = []self.probs = []self.actions = []self.rewards = []self.dones = []self.vals = []class Actor(nn.Module):def __init__(self, n_states, n_actions,hidden_dim):super(Actor, self).__init__()self.actor = nn.Sequential(nn.Linear(n_states, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, n_actions),nn.Softmax(dim=-1))def forward(self, state):dist = self.actor(state)dist = Categorical(dist)return distclass Critic(nn.Module):def __init__(self, n_states, hidden_dim):super(Critic, self).__init__()self.critic = nn.Sequential(nn.Linear(n_states, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1))def forward(self, state):value = self.critic(state)return valueclass PPO:def __init__(self, n_states, n_actions, cfg):self.gamma = cfg['gamma']self.continuous = cfg['continuous']self.policy_clip = cfg['policy_clip']self.n_epochs = cfg['n_epochs']self.gae_lambda = cfg['gae_lambda']self.device = cfg['device']self.actor = Actor(n_states, n_actions, cfg['hidden_dim']).to(self.device)self.critic = Critic(n_states, cfg['hidden_dim']).to(self.device)self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=cfg['actor_lr'])self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=cfg['critic_lr'])self.memory = PPOMemory(cfg['batch_size'])self.loss = 0def choose_action(self, state):state = np.array([state]) # 先转成数组再转tensor更高效state = torch.tensor(state, dtype=torch.float).to(self.device)dist = self.actor(state)value = self.critic(state)action = dist.sample()probs = torch.squeeze(dist.log_prob(action)).item()if self.continuous:action = torch.tanh(action)else:action = torch.squeeze(action).item()value = torch.squeeze(value).item()return action, probs, valuedef update(self):for _ in range(self.n_epochs):state_arr, action_arr, old_prob_arr, vals_arr, reward_arr, dones_arr, batches = self.memory.sample()values = vals_arr[:]### compute advantage ###advantage = np.zeros(len(reward_arr), dtype=np.float32)for t in range(len(reward_arr) - 1):discount = 1a_t = 0for k in range(t, len(reward_arr) - 1):a_t += discount * (reward_arr[k] + self.gamma * values[k + 1] * \(1 - int(dones_arr[k])) - values[k])discount *= self.gamma * self.gae_lambdaadvantage[t] = a_tadvantage = torch.tensor(advantage).to(self.device)### SGD ###values = torch.tensor(values).to(self.device)for batch in batches:states = torch.tensor(state_arr[batch], dtype=torch.float).to(self.device)old_probs = torch.tensor(old_prob_arr[batch]).to(self.device)actions = torch.tensor(action_arr[batch]).to(self.device)dist = self.actor(states)critic_value = self.critic(states)critic_value = torch.squeeze(critic_value)new_probs = dist.log_prob(actions)prob_ratio = new_probs.exp() / old_probs.exp()weighted_probs = advantage[batch] * prob_ratioweighted_clipped_probs = torch.clamp(prob_ratio, 1 - self.policy_clip,1 + self.policy_clip) * advantage[batch]actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean()returns = advantage[batch] + values[batch]critic_loss = (returns - critic_value) ** 2critic_loss = critic_loss.mean()total_loss = actor_loss + 0.5 * critic_lossself.loss = total_lossself.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()total_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()self.memory.clear()def save_model(self, path):Path(path).mkdir(parents=True, exist_ok=True)actor_checkpoint = os.path.join(path, 'ppo_actor.pt')critic_checkpoint = os.path.join(path, 'ppo_critic.pt')torch.save(self.actor.state_dict(), actor_checkpoint)torch.save(self.critic.state_dict(), critic_checkpoint)def load_model(self, path):actor_checkpoint = os.path.join(path, 'ppo_actor.pt')critic_checkpoint = os.path.join(path, 'ppo_critic.pt')self.actor.load_state_dict(torch.load(actor_checkpoint))self.critic.load_state_dict(torch.load(critic_checkpoint))# 训练函数
def train(arg_dict, env, agent):# 开始计时startTime = time.time()print(f"环境名: {arg_dict['env_name']}, 算法名: {arg_dict['algo_name']}, Device: {arg_dict['device']}")print("开始训练智能体......")rewards = [] # 记录所有回合的奖励ma_rewards = [] # 记录所有回合的滑动平均奖励steps = 0for i_ep in range(arg_dict['train_eps']):state = env.reset()done = Falseep_reward = 0while not done:# 画图if arg_dict['train_render']:env.render()action, prob, val = agent.choose_action(state)state_, reward, done, _ = env.step(action)steps += 1ep_reward += rewardagent.memory.push(state, action, prob, val, reward, done)if steps % arg_dict['update_fre'] == 0:agent.update()state = state_rewards.append(ep_reward)if ma_rewards:ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)else:ma_rewards.append(ep_reward)if (i_ep + 1) % 10 == 0:print(f"回合:{i_ep + 1}/{arg_dict['train_eps']},奖励:{ep_reward:.2f}")print('训练结束 , 用时: ' + str(time.time() - startTime) + " s")# 关闭环境env.close()return {'episodes': range(len(rewards)), 'rewards': rewards}# 测试函数
def test(arg_dict, env, agent):startTime = time.time()print("开始测试智能体......")print(f"环境名: {arg_dict['env_name']}, 算法名: {arg_dict['algo_name']}, Device: {arg_dict['device']}")rewards = [] # 记录所有回合的奖励ma_rewards = [] # 记录所有回合的滑动平均奖励for i_ep in range(arg_dict['test_eps']):state = env.reset()done = Falseep_reward = 0while not done:# 画图if arg_dict['test_render']:env.render()action, prob, val = agent.choose_action(state)state_, reward, done, _ = env.step(action)ep_reward += rewardstate = state_rewards.append(ep_reward)if ma_rewards:ma_rewards.append(0.9 * ma_rewards[-1] + 0.1 * ep_reward)else:ma_rewards.append(ep_reward)print('回合:{}/{}, 奖励:{}'.format(i_ep + 1, arg_dict['test_eps'], ep_reward))print("测试结束 , 用时: " + str(time.time() - startTime) + " s")env.close()return {'episodes': range(len(rewards)), 'rewards': rewards}# 创建环境和智能体
def create_env_agent(arg_dict):# 创建环境env = gym.make(arg_dict['env_name'])# 设置随机种子all_seed(env, seed=arg_dict["seed"])# 获取状态数try:n_states = env.observation_space.nexcept AttributeError:n_states = env.observation_space.shape[0]# 获取动作数n_actions = env.action_space.nprint(f"状态数: {n_states}, 动作数: {n_actions}")# 将状态数和动作数加入算法参数字典arg_dict.update({"n_states": n_states, "n_actions": n_actions})# 实例化智能体对象agent = PPO(n_states, n_actions, arg_dict)# 返回环境,智能体return env, agentif __name__ == '__main__':# 防止报错 OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"# 获取当前路径curr_path = os.path.dirname(os.path.abspath(__file__))# 获取当前时间curr_time = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")# 相关参数设置parser = argparse.ArgumentParser(description="hyper parameters")parser.add_argument('--algo_name', default='PPO', type=str, help="name of algorithm")parser.add_argument('--env_name', default='CartPole-v0', type=str, help="name of environment")parser.add_argument('--continuous', default=False, type=bool,help="if PPO is continuous") # PPO既可适用于连续动作空间,也可以适用于离散动作空间parser.add_argument('--train_eps', default=200, type=int, help="episodes of training")parser.add_argument('--test_eps', default=20, type=int, help="episodes of testing")parser.add_argument('--gamma', default=0.99, type=float, help="discounted factor")parser.add_argument('--batch_size', default=5, type=int) # mini-batch SGD中的批量大小parser.add_argument('--n_epochs', default=4, type=int)parser.add_argument('--actor_lr', default=0.0003, type=float, help="learning rate of actor net")parser.add_argument('--critic_lr', default=0.0003, type=float, help="learning rate of critic net")parser.add_argument('--gae_lambda', default=0.95, type=float)parser.add_argument('--policy_clip', default=0.2, type=float) # PPO-clip中的clip参数,一般是0.1~0.2左右parser.add_argument('--update_fre', default=20, type=int)parser.add_argument('--hidden_dim', default=256, type=int)parser.add_argument('--device', default='cpu', type=str, help="cpu or cuda")parser.add_argument('--seed', default=520, type=int, help="seed")parser.add_argument('--show_fig', default=False, type=bool, help="if show figure or not")parser.add_argument('--save_fig', default=True, type=bool, help="if save figure or not")parser.add_argument('--train_render', default=False, type=bool,help="Whether to render the environment during training")parser.add_argument('--test_render', default=True, type=bool,help="Whether to render the environment during testing")args = parser.parse_args()default_args = {'result_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/results/",'model_path': f"{curr_path}/outputs/{args.env_name}/{curr_time}/models/",}# 将参数转化为字典 type(dict)arg_dict = {**vars(args), **default_args}print("算法参数字典:", arg_dict)# 创建环境和智能体env, agent = create_env_agent(arg_dict)# 传入算法参数、环境、智能体,然后开始训练res_dic = train(arg_dict, env, agent)print("算法返回结果字典:", res_dic)# 保存相关信息agent.save_model(path=arg_dict['model_path'])save_args(arg_dict, path=arg_dict['result_path'])save_results(res_dic, tag='train', path=arg_dict['result_path'])plot_rewards(res_dic['rewards'], arg_dict, path=arg_dict['result_path'], tag="train")# =================================================================================================# 创建新环境和智能体用来测试print("=" * 300)env, agent = create_env_agent(arg_dict)# 加载已保存的智能体agent.load_model(path=arg_dict['model_path'])res_dic = test(arg_dict, env, agent)save_results(res_dic, tag='test', path=arg_dict['result_path'])plot_rewards(res_dic['rewards'], arg_dict, path=arg_dict['result_path'], tag="test")
4.3 运行结果展示
由于有些输出太长,下面仅展示部分输出
状态数: 4, 动作数: 2
环境名: CartPole-v0, 算法名: PPO, Device: cpu
开始训练智能体......
回合:10/200,奖励:14.00
回合:20/200,奖励:36.00
回合:30/200,奖励:21.00
回合:40/200,奖励:23.00
回合:50/200,奖励:25.00
回合:60/200,奖励:155.00
回合:70/200,奖励:200.00
回合:80/200,奖励:101.00
回合:90/200,奖励:153.00
回合:100/200,奖励:145.00
回合:110/200,奖励:166.00
回合:120/200,奖励:200.00
回合:130/200,奖励:200.00
回合:140/200,奖励:200.00
回合:150/200,奖励:200.00
回合:160/200,奖励:144.00
回合:170/200,奖励:200.00
回合:180/200,奖励:200.00
回合:190/200,奖励:200.00
回合:200/200,奖励:200.00
训练结束 , 用时: 130.60313510894775 s
============================================================================================================================================================================================================================================================================================================
状态数: 4, 动作数: 2
开始测试智能体......
环境名: CartPole-v0, 算法名: PPO, Device: cpu
回合:1/20, 奖励:200.0
回合:2/20, 奖励:200.0
回合:3/20, 奖励:200.0
回合:4/20, 奖励:200.0
回合:5/20, 奖励:200.0
回合:6/20, 奖励:200.0
回合:7/20, 奖励:200.0
回合:8/20, 奖励:200.0
回合:9/20, 奖励:200.0
回合:10/20, 奖励:200.0
回合:11/20, 奖励:200.0
回合:12/20, 奖励:200.0
回合:13/20, 奖励:200.0
回合:14/20, 奖励:200.0
回合:15/20, 奖励:200.0
回合:16/20, 奖励:200.0
回合:17/20, 奖励:200.0
回合:18/20, 奖励:181.0
回合:19/20, 奖励:200.0
回合:20/20, 奖励:125.0
测试结束 , 用时: 31.763733386993408 s
PPO算法测试
策略梯度算法测试:【强化学习】Policy Gradient 策略梯度算法求解CartPole倒立摆问题 + Python代码实战
是不是明显感觉到经过PPO算法训练出来的智能体在测试中表现得更加稳呢!
4.4 关于可视化的设置
如果你觉得可视化比较耗时,你可以进行设置,取消可视化。
或者你想看看训练过程的可视化,也可以进行相关设置
【强化学习】PPO算法求解倒立摆问题 + Pytorch代码实战相关推荐
- 【强化学习】优势演员-评论员算法(Advantage Actor-Critic , A2C)求解倒立摆问题 + Pytorch代码实战
文章目录 一.倒立摆问题介绍 二.优势演员-评论员算法简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.2.1 网络参数不共享版本 4.2.2 网络参数共享版本 ...
- 【强化学习】双深度Q网络(DDQN)求解倒立摆问题 + Pytorch代码实战
文章目录 一.倒立摆问题介绍 二.双深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 A ...
- 【强化学习】竞争深度Q网络(Dueling DQN)求解倒立摆问题 + Pytorch代码实战
文章目录 一.倒立摆问题介绍 二.竞争深度Q网络简介 三.详细资料 四.Python代码实战 4.1 运行前配置 4.2 主要代码 4.3 运行结果展示 4.4 关于可视化的设置 一.倒立摆问题介绍 ...
- 【强化学习PPO算法】
强化学习PPO算法 一.PPO算法 二.伪代码 三.相关的简单理论 1.ratio 2.裁断 3.Advantage的计算 4.loss的计算 四.算法实现 五.效果 六.感悟 最近再改一个代码, ...
- 强化学习ppo算法详解
PPO (Proximal Policy Optimization) 是一种基于梯度的强化学习算法.它的主要思想是通过对策略的更新来提高策略的效率.主要包括以下步骤: 首先选取一个初始策略,然后使用这 ...
- 【强化学习】Sarsa算法求解悬崖行走问题 + Python代码实战
文章目录 一.Sarsa算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视化寻路过程 ...
- 【强化学习】Q-Learning算法求解悬崖行走问题 + Python代码实战
文章目录 一.Q-Learning算法简介 1.1 更新公式 1.2 预测策略 1.3 详细资料 二.Python代码实战 2.1 运行前配置 2.2 主要代码 2.3 运行结果展示 2.4 关于可视 ...
- 【原创】强化学习笔记|从零开始学习PPO算法编程(pytorch版本)
从零开始学习PPO算法编程(pytorch版本)_melody_cjw的博客-CSDN博客_ppo算法 pytorch 从零开始学习PPO算法编程(pytorch版本)(二)_melody_cjw的博 ...
- 强化学习经典算法笔记(十二):近端策略优化算法(PPO)实现,基于A2C(下)
强化学习经典算法笔记(十二):近端策略优化算法(PPO)实现,基于A2C 本篇实现一个基于A2C框架的PPO算法,应用于连续动作空间任务. import torch import torch.nn a ...
最新文章
- 中国油气装备行业发展状况与投资前景咨询报告2022-2028年版
- python下载过程中最后一步执行opencv出错怎么回事_如何修复python中opencv中的错误“QObject::moveToThread:”?...
- pythonflat怎么设置_python – numpy 2d和1d add flat
- pythontcp文件传输_python socket实现文件传输(防粘包)
- command对象和DataReader的学习
- python的集合类型_python集合类型
- python装饰器的应用案例
- 操作系统 第二部分 进程管理(五)
- ftp等远程登录工具的星号密码查看方法
- 计算机学安杰拉,《朗文高级英语阅读参考-(上册)》.pdf
- 美国金融危机产生的原因
- Linux驱动开发---杂项设备
- OpenCV-Python 中文教程
- 计算机控制台win10,Win10系统打开Windows控制台的方法
- 小米四启用虚拟按键以及禁用实体按键
- windows下使用vscode编写运行以及调试C/C++
- xamarin android 微信,转换微信SDK为Xamarin绑定库 Android5.5.8 iOS1.8.6.2
- 大豆技术面分析_大豆高产栽培关键技术分析,简单、明了轻松学会
- 阿里达摩院的AI Earth(AIE)云平台介绍
- scrollY,scrollTo
热门文章
- 运行tomcat7w.exe,提示:指定的服务未安装unable to open the service tomcat7
- Javascript 实现gb2312和utf8编码的互换
- System Verilog约束块(constrain block)控制和随机变量的随机属性控制
- awk,sed,grep
- 题目错题记录表mysql设计_基于Web2.0的跨平台电子错题本功能的设计与实现
- 数字人民币概论、特征、架构介绍
- Mean Squared Error 和 Maximum-A-Posterior (Maximum Likelihood Estimation) 的关系
- MBA/MEM 复试准备(03)面试礼仪
- c语言中的二目运算符,C语言中的三目运算符是什么
- linux环境变量设置图解,Ubuntu Linux 各个环境变量配置文件详解, 环境变量PATH设置...