Reinforcement Learning Classic Algorithm Notes (12): Implementing Proximal Policy Optimization (PPO) on top of A2C

This post implements a PPO algorithm built on the A2C framework and applies it to a continuous-action-space task.
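For reference, the update implemented below optimizes the standard PPO clipped surrogate objective (notation as in the PPO paper; the clipping corresponds to the torch.clamp call in PPO.update):

L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}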

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

The difference from the A2C implementation in the previous post lies in how actions are selected: the Actor outputs the mean vector of a multivariate Gaussian distribution, while the covariance matrix is fixed by hand (the variance could, of course, also be learned).
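As a minimal, standalone sketch of this action-selection step (the zero mean stands in for the Actor's output; the 4-dimensional action space matches BipedalWalker and action_std = 0.5 matches the hyperparameters used below; the real logic lives in ActorCritic.act):

import torch
from torch.distributions import MultivariateNormal

action_dim = 4                                        # illustrative: BipedalWalker has 4 action dimensions
action_std = 0.5                                      # hand-picked standard deviation
action_mean = torch.zeros(action_dim)                 # stands in for the Actor network's output
action_var = torch.full((action_dim,), action_std ** 2)
cov_mat = torch.diag(action_var)                      # fixed diagonal covariance matrix

dist = MultivariateNormal(action_mean, cov_mat)
action = dist.sample()                                # one continuous action vector
logprob = dist.log_prob(action)                       # stored for the PPO ratio later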

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, action_dim),
            nn.Tanh()
        )
        # critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )
        self.action_var = torch.full((action_dim,), action_std * action_std).to(device)

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):
        action_mean = torch.squeeze(self.actor(state))
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        dist = MultivariateNormal(action_mean, cov_mat)

        action_logprobs = dist.log_prob(torch.squeeze(action))
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy
class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()

    def update(self, memory):
        # Monte Carlo estimate of rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards (float32 so it matches the network outputs in MseLoss):
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device)).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device)).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs)).to(device).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values:
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta_old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
def main():
    ############## Hyperparameters ##############
    env_name = "BipedalWalker-v2"
    render = False
    solved_reward = 300         # stop training if avg_reward > solved_reward
    log_interval = 20           # print avg reward in the interval
    max_episodes = 10000        # max training episodes
    max_timesteps = 1500        # max timesteps in one episode
    update_timestep = 4000      # update policy every n timesteps
    action_std = 0.5            # constant std for action distribution (Multivariate Normal)
    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor
    lr = 0.0003                 # parameters for Adam optimizer
    betas = (0.9, 0.999)
    random_seed = None
    #############################################

    # creating environment
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    if random_seed:
        print("Random Seed: {}".format(random_seed))
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)

    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    print(lr, betas)

    # logging variables
    running_reward = 0
    avg_length = 0
    time_step = 0

    # training loop
    for i_episode in range(1, max_episodes + 1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step += 1
            # Running policy_old:
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)

            # Saving reward and is_terminals:
            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            # update if its time
            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0
            running_reward += reward
            if render:
                env.render()
            if done:
                break

        avg_length += t

        # stop training if avg_reward > solved_reward
        if running_reward > (log_interval * solved_reward):
            print("########## Solved! ##########")
            torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
            break

        # save every 500 episodes
        if i_episode % 500 == 0:
            torch.save(ppo.policy.state_dict(), './PPO_continuous_{}.pth'.format(env_name))

        # logging
        if i_episode % log_interval == 0:
            avg_length = int(avg_length / log_interval)
            running_reward = int((running_reward / log_interval))

            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
            running_reward = 0
            avg_length = 0


if __name__ == '__main__':
    main()

Test code

import gym
from PPO_continuous import PPO, Memory
from PIL import Image
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def test():
    ############## Hyperparameters ##############
    env_name = "BipedalWalker-v2"
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    n_episodes = 3          # num of episodes to run
    max_timesteps = 1500    # max timesteps in one episode
    render = True           # render the environment
    save_gif = False        # png images are saved in gif folder

    # filename and directory to load model from
    filename = "PPO_continuous_" + env_name + ".pth"
    directory = "./preTrained/"

    action_std = 0.5        # constant std for action distribution (Multivariate Normal)
    K_epochs = 80           # update policy for K epochs
    eps_clip = 0.2          # clip parameter for PPO
    gamma = 0.99            # discount factor

    lr = 0.0003             # parameters for Adam optimizer
    betas = (0.9, 0.999)
    #############################################

    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    ppo.policy_old.load_state_dict(torch.load(directory + filename))

    for ep in range(1, n_episodes + 1):
        ep_reward = 0
        state = env.reset()
        for t in range(max_timesteps):
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)
            ep_reward += reward
            if render:
                env.render()
            if save_gif:
                img = env.render(mode='rgb_array')
                img = Image.fromarray(img)
                img.save('./gif/{}.jpg'.format(t))
            if done:
                break

        print('Episode: {}\tReward: {}'.format(ep, int(ep_reward)))
        ep_reward = 0
        env.close()


if __name__ == '__main__':
    test()
