Preface

  1. Implementations of all kinds of DRL algorithms are scattered across GitHub, e.g. Mofan's DRL code and ElegantDRL (No. 1 for readability).
  2. Many of them are not the best implementations of the original algorithms and differ in implementation details, so I don't recommend using them directly for research.
  3. The code in this post is adapted from the OpenAI Spinning Up source code (DRL_OpenAI), so performance is no longer something you need to worry about.
  4. Why rewrite it? Because the original source pulls in too many dependencies, is hard for beginners to read, and is full of logger code that gives people headaches.
  5. The code in this post reduces environment dependencies to a minimum and drops some auxiliary features, making it much easier to read.
  6. SAC is a fairly recent algorithm with outstanding performance.

The project consists of three files: main.py, SACModel.py and core.py.
Python 3.6

SACModel.py

import torch
from torch.optim import Adam
from copy import deepcopy
import itertools
import core as core
import numpy as np


class ReplayBuffer:
    """A simple FIFO experience replay buffer for SAC agents."""

    def __init__(self, obs_dim, act_dim, size):
        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, rew, next_obs, done):
        self.obs_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(0, self.size, size=batch_size)
        batch = dict(obs=self.obs_buf[idxs],
                     obs2=self.obs2_buf[idxs],
                     act=self.act_buf[idxs],
                     rew=self.rew_buf[idxs],
                     done=self.done_buf[idxs])
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in batch.items()}


class SAC:

    def __init__(self, obs_dim, act_dim, act_bound, actor_critic=core.MLPActorCritic, seed=0,
                 replay_size=int(1e6), gamma=0.99, polyak=0.995, lr=1e-3, alpha=0.2):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.act_bound = act_bound
        self.gamma = gamma
        self.polyak = polyak
        self.alpha = alpha

        torch.manual_seed(seed)
        np.random.seed(seed)

        # act_bound is [-act_limit, act_limit], so pass its upper bound to the actor-critic
        self.ac = actor_critic(obs_dim, act_dim, act_limit=act_bound[1])
        self.ac_targ = deepcopy(self.ac)

        # Freeze target networks with respect to optimizers (only update via polyak averaging)
        for p in self.ac_targ.parameters():
            p.requires_grad = False

        # List of parameters for both Q-networks (save this for convenience).
        # Wrapped in list() so it can be iterated more than once when freezing/unfreezing below.
        self.q_params = list(itertools.chain(self.ac.q1.parameters(), self.ac.q2.parameters()))

        # Set up optimizers for policy and q-function
        self.pi_optimizer = Adam(self.ac.pi.parameters(), lr=lr)
        self.q_optimizer = Adam(self.q_params, lr=lr)

        # Experience buffer
        self.replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(self, data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = self.ac.q1(o, a)
        q2 = self.ac.q2(o, a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = self.ac.pi(o2)

            # Target Q-values
            q1_pi_targ = self.ac_targ.q1(o2, a2)
            q2_pi_targ = self.ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + self.gamma * (1 - d) * (q_pi_targ - self.alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(),
                      Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(self, data):
        o = data['obs']
        pi, logp_pi = self.ac.pi(o)
        q1_pi = self.ac.q1(o, pi)
        q2_pi = self.ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (self.alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().numpy())

        return loss_pi, pi_info

    def update(self, data):
        # First run one gradient descent step for Q1 and Q2
        self.q_optimizer.zero_grad()
        loss_q, q_info = self.compute_loss_q(data)
        loss_q.backward()
        self.q_optimizer.step()

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in self.q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        self.pi_optimizer.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(data)
        loss_pi.backward()
        self.pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize them at the next SAC step.
        for p in self.q_params:
            p.requires_grad = True

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(self.ac.parameters(), self.ac_targ.parameters()):
                # NB: We use in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(self.polyak)
                p_targ.data.add_((1 - self.polyak) * p.data)

    def get_action(self, o, deterministic=False):
        return self.ac.act(torch.as_tensor(o, dtype=torch.float32), deterministic)
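
Before wiring SACModel.py into a Gym environment, it can be handy to sanity-check it in isolation. The snippet below is a minimal smoke test of my own (not part of the three project files); it assumes SACModel.py and core.py sit in the working directory, fills the replay buffer with random transitions, and runs a single update step:

# Minimal smoke test for SACModel.py (core.py must be importable from the same directory).
import numpy as np
from SACModel import SAC

obs_dim, act_dim = 3, 1                      # Pendulum-v0 sizes, used here only as an example
agent = SAC(obs_dim, act_dim, act_bound=[-2.0, 2.0])

for _ in range(200):                         # store a few random transitions
    o = np.random.randn(obs_dim)
    a = np.random.uniform(-2.0, 2.0, size=act_dim)
    o2 = np.random.randn(obs_dim)
    agent.replay_buffer.store(o, a, 0.0, o2, False)

batch = agent.replay_buffer.sample_batch(batch_size=32)
agent.update(data=batch)                     # one gradient step on Q1, Q2 and the policy
print(agent.get_action(np.random.randn(obs_dim)))   # stochastic action within [-2, 2]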

core.py

import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal


def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)


def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes) - 1):
        act = activation if j < len(sizes) - 2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j + 1]), act()]
    return nn.Sequential(*layers)


def count_vars(module):
    return sum([np.prod(p.shape) for p in module.parameters()])


LOG_STD_MAX = 2
LOG_STD_MIN = -20


class SquashedGaussianMLPActor(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # NOTE: The correction formula is a little bit magic. To get an understanding
            # of where it comes from, check out the original SAC paper (arXiv 1801.01290)
            # and look in appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi


class MLPQFunction(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1)  # Critical to ensure q has right shape.


class MLPActorCritic(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes=(256, 256),
                 activation=nn.ReLU, act_limit=2.0):
        super().__init__()

        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            return a.numpy()
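
The "a little bit magic" correction in SquashedGaussianMLPActor relies on the identity log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u)), which is just a numerically stable rewrite of the tanh-squashing Jacobian term. A quick standalone check (my own sketch, not part of the three project files):

# Numerically verify: log(1 - tanh(u)^2) == 2 * (log 2 - u - softplus(-2u))
import numpy as np
import torch
import torch.nn.functional as F

u = torch.linspace(-5.0, 5.0, steps=1001, dtype=torch.float64)
naive = torch.log(1 - torch.tanh(u) ** 2)           # direct form, underflows for large |u|
stable = 2 * (np.log(2) - u - F.softplus(-2 * u))   # form used in core.py's forward()
print(torch.allclose(naive, stable))                # expected: True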

main.py

from SACModel import *
import gym
import matplotlib.pyplot as plt


if __name__ == '__main__':
    # Pendulum-v0: continuous-action inverted pendulum with action range [-2, 2],
    # matching act_bound below (SAC needs a continuous action space).
    env = gym.make('Pendulum-v0')
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    act_bound = [-env.action_space.high[0], env.action_space.high[0]]

    sac = SAC(obs_dim, act_dim, act_bound)

    MAX_EPISODE = 100
    MAX_STEP = 500
    update_every = 50
    batch_size = 100
    rewardList = []

    for episode in range(MAX_EPISODE):
        o = env.reset()
        ep_reward = 0
        for j in range(MAX_STEP):
            # random exploration for the first episodes, then act with the policy
            if episode > 20:
                env.render()
                a = sac.get_action(o)
            else:
                a = env.action_space.sample()

            o2, r, d, _ = env.step(a)
            sac.replay_buffer.store(o, a, r, o2, d)

            if episode >= 10 and j % update_every == 0:
                for _ in range(update_every):
                    batch = sac.replay_buffer.sample_batch(batch_size)
                    sac.update(data=batch)

            o = o2
            ep_reward += r
            if d:
                break

        print('Episode:', episode, 'Reward:%i' % int(ep_reward))
        rewardList.append(ep_reward)

    plt.figure()
    plt.plot(np.arange(len(rewardList)), rewardList)
    plt.show()
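
The training loop above always samples stochastic actions. After training, the policy is usually evaluated with the deterministic mean action instead; a minimal evaluation sketch of my own that could be appended to main.py (it reuses the env and sac objects defined above):

    # Hypothetical post-training evaluation: use the mean action (deterministic=True)
    o, eval_reward = env.reset(), 0
    for _ in range(200):
        a = sac.get_action(o, deterministic=True)
        o, r, d, _ = env.step(a)
        eval_reward += r
        if d:
            break
    print('Deterministic evaluation reward:', eval_reward)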

Reward curve for the 'Pendulum-v0' inverted pendulum experiment


Since the inverted pendulum environment is fairly simple, I compared this implementation against Spinning Up's DDPG and the gap was not very noticeable. You can swap in a more complex continuous-control environment for further testing, as sketched below.
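
A minimal sketch of such a swap, assuming Gym's Box2D extras are installed for LunarLanderContinuous-v2 (any Gym environment with a Box action space works the same way); only the environment setup changes, the training loop in main.py stays identical:

import gym
from SACModel import SAC

env = gym.make('LunarLanderContinuous-v2')
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
act_bound = [-env.action_space.high[0], env.action_space.high[0]]
sac = SAC(obs_dim, act_dim, act_bound)   # then reuse the training loop from main.py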
