For the theoretical background, see: 强化学习笔记:Actor-critic_UQI-LIUWJ的博客-CSDN博客

Since actor-critic is a combination of policy gradient and DQN, much of the code is very close to the policy network and DQN code in these two posts:

pytorch笔记:policy gradient_UQI-LIUWJ的博客-CSDN博客

pytorch 笔记: DQN(experience replay)_UQI-LIUWJ的博客-CSDN博客

1 Imports & hyperparameters

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from torch.distributions import Categorical

GAMMA = 0.95    # reward discount factor
LR = 0.01       # learning rate
EPISODE = 3000  # number of episodes to generate
STEP = 3000     # maximum number of steps within one episode
TEST = 10       # number of test episodes run after every 100 training episodes

2 Actor

2.1 The actor network class

class PGNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PGNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_scores = self.fc2(x)
        return F.softmax(action_scores, dim=1)

# PGNetwork takes the state vector at a given time step as input and outputs
# the probability of each action being taken.
# It is the same as the Policy network in the policy gradient post.
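As a quick sanity check of that description, a minimal sketch (my own addition, reusing the imports from section 1) pushes a dummy CartPole-sized state through PGNetwork and confirms that the output row is a probability distribution:

net = PGNetwork(state_dim=4, action_dim=2)
dummy_state = torch.randn(1, 4)        # a batch containing one 4-dimensional state
probs = net(dummy_state)               # shape [1, 2]
print(probs.shape, probs.sum(dim=1))   # the two action probabilities sum to 1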

2.2 The Actor class

2.2.1 __init__

class Actor(object):
    def __init__(self, env):  # initialization
        self.state_dim = env.observation_space.shape[0]
        # dimensionality of the state at a given time step
        # for the cart-pole problem this is 4
        self.action_dim = env.action_space.n
        # size of the action space (how many different actions are available)
        # for the cart-pole problem this is 2
        self.network = PGNetwork(state_dim=self.state_dim, action_dim=self.action_dim)
        # takes S as input and outputs the probability of each action being taken
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
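The two dimensions quoted in the comments can be verified directly; a minimal sketch (my addition, assuming a standard gym CartPole installation):

env = gym.make('CartPole-v1')
print(env.observation_space.shape[0])  # 4: cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space.n)              # 2: push the cart left or right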

2.2.2 Choosing an action

Almost identical to the version in the policy gradient post.

    def choose_action(self, observation):
        # Choose an action: not by picking the largest Q value, but by sampling
        # from the softmax probabilities.
        # Policy gradient and A2C do not need epsilon-greedy exploration,
        # because the probabilities themselves provide the randomness.
        observation = torch.from_numpy(observation).float().unsqueeze(0)
        # print(observation.shape)   # torch.Size([1, 4])
        # unsqueeze turns the state into a [1, 4] tensor
        probs = self.network(observation)
        # output of the policy network: the probability of each action in this state
        m = Categorical(probs)  # build the distribution
        action = m.sample()     # sample from it (according to the action probabilities)
        # print(m.log_prob(action))
        # m.log_prob(action) is equivalent to probs.log()[0][action.item()].unsqueeze(0),
        # i.e. the probability of the chosen action, passed through log
        return action.item()    # return a plain Python scalar

    # So each call to choose_action selects a reasonable action and returns it.
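The comment about m.log_prob can be checked in isolation; a minimal sketch (my addition) with a hand-written probability vector:

probs = torch.tensor([[0.3, 0.7]])
m = Categorical(probs)
action = torch.tensor([1])
print(m.log_prob(action))                           # tensor([-0.3567]) = log(0.7)
print(probs.log()[0][action.item()].unsqueeze(0))   # the same value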

2.2.3 Training the actor network

In other words, learning how to choose actions better.

The td_error used here is computed by the critic (section 3.2.2 below); neg_log_prob, computed with cross-entropy, plays the role of -log π(a|s) for the chosen action.

    def learn(self, state, action, td_error):
        observation = torch.from_numpy(state).float().unsqueeze(0)
        softmax_input = self.network(observation)
        # probability of each action being taken
        action = torch.LongTensor([action])
        neg_log_prob = F.cross_entropy(input=softmax_input, target=action)

        # Backpropagation (gradient ascent):
        # we want to maximize the value of the current policy, so here
        # neg_log_prob * td_error is maximized, i.e. -neg_log_prob * td_error is minimized
        loss_a = -neg_log_prob * td_error
        self.optimizer.zero_grad()
        loss_a.backward()
        self.optimizer.step()
        # the usual three PyTorch steps
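Two caveats here, both my own reading rather than anything stated in the original post. First, F.cross_entropy expects unnormalized logits and applies log_softmax internally, so feeding it softmax probabilities yields only an approximation of -log π(a|s). Second, the textbook policy-gradient loss is neg_log_prob * td_error (minimizing it raises the probability of actions with a positive TD error), whereas the extra minus sign above flips that, which may be one reason the evaluation rewards printed in section 4 hover around 9-10 instead of improving. A hedged alternative method for the Actor class, sketched under those assumptions:

    # A hedged alternative to learn() (a sketch, not the author's code): use the
    # exact log-probability from Categorical and the conventional policy-gradient
    # sign, so that a positive td_error increases the chosen action's probability.
    def learn_alternative(self, state, action, td_error):
        observation = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.network(observation)                                   # pi(.|s)
        log_prob = Categorical(probs).log_prob(torch.LongTensor([action]))  # log pi(a|s)
        loss_a = -log_prob * td_error   # gradient ascent on log pi(a|s) * td_error
        self.optimizer.zero_grad()
        loss_a.backward()
        self.optimizer.step()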

3 Critic

Using the transitions sampled by the actor, V(s) is estimated with TD learning.

For simplicity, neither a target network nor experience replay is used here; both are covered in the DQN PyTorch code linked above.

3.1 The critic network class

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, 1)
        # Slightly different from before: the output is one-dimensional rather
        # than action_dim-dimensional, because here we estimate V(s), whereas
        # DQN estimates Q(s, a) and therefore needs one output per action.

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out
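A minimal shape check (my addition) that contrasts the two networks: QNetwork returns a single scalar V(s) per state, while PGNetwork returns one probability per action:

value_net = QNetwork(state_dim=4, action_dim=2)   # action_dim is accepted but unused inside QNetwork
policy_net = PGNetwork(state_dim=4, action_dim=2)
s = torch.randn(1, 4)
print(value_net(s).shape)    # torch.Size([1, 1])
print(policy_net(s).shape)   # torch.Size([1, 2])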

3.2 The Critic class

3.2.1 __init__

class Critic(object):
    # Learns V(S) from the sampled transitions.
    def __init__(self, env):
        self.state_dim = env.observation_space.shape[0]
        # dimensionality of the state at a given time step
        # for the cart-pole problem this is 4
        self.action_dim = env.action_space.n
        # size of the action space (how many different actions are available)
        # for the cart-pole problem this is 2
        self.network = QNetwork(state_dim=self.state_dim, action_dim=self.action_dim)
        # takes S as input and outputs V(S)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

3.2.2 Training the critic network

    def train_Q_network(self, state, reward, next_state):
        # Similar to section 5.4 of the DQN post, but without the fixed (target)
        # network and experience replay mechanisms.
        s, s_ = torch.FloatTensor(state), torch.FloatTensor(next_state)
        # current state, and the state reached after executing the action
        v = self.network(s)     # V(s)
        v_ = self.network(s_)   # V(s')

        # Backpropagation: TD loss, the gap between r + γV(s') and V(s)
        loss_q = self.loss_func(reward + GAMMA * v_, v)
        self.optimizer.zero_grad()
        loss_q.backward()
        self.optimizer.step()
        # the usual three PyTorch steps

        with torch.no_grad():
            td_error = reward + GAMMA * v_ - v
            # no_grad keeps the TD error out of the computation graph, so no
            # gradient flows from the actor back into the critic
            # (the actor and critic are trained independently)
        return td_error
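One common variant, shown as a hedged sketch rather than as the author's code: many A2C implementations detach the bootstrapped target r + GAMMA * V(s'), so that the MSE loss only pulls V(s) toward the target instead of also pulling V(s') toward V(s). A drop-in method for the Critic class, under that assumption:

    # A hedged variant of train_Q_network (a sketch, not the original code):
    # detach the bootstrapped target so the critic only regresses V(s) toward it.
    def train_Q_network_detached_target(self, state, reward, next_state):
        s, s_ = torch.FloatTensor(state), torch.FloatTensor(next_state)
        v = self.network(s)                                # V(s), with gradient
        with torch.no_grad():
            target = reward + GAMMA * self.network(s_)     # treated as a constant
        loss_q = self.loss_func(v, target)
        self.optimizer.zero_grad()
        loss_q.backward()
        self.optimizer.step()
        with torch.no_grad():
            td_error = target - v                          # same TD error the actor consumes
        return td_error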

4 Main function

def main():
    env = gym.make('CartPole-v1')
    # create a cart-pole gym environment
    actor = Actor(env)
    critic = Critic(env)

    for episode in range(EPISODE):
        state = env.reset()
        # state is the initial state of this episode's environment
        for step in range(STEP):
            action = actor.choose_action(state)  # let the actor pick an action
            next_state, reward, done, _ = env.step(action)
            # the return values are state, reward, done (whether the environment
            # needs to be reset) and info
            td_error = critic.train_Q_network(state, reward, next_state)
            # gradient = grad[r + gamma * V(s_) - V(s)]
            # first train the critic on the sampled action, current state and
            # next state, to get a more accurate V(s)
            actor.learn(state, action, td_error)
            # true_gradient = grad[logPi(a|s) * td_error]
            # then use the learned V(s) to train the actor, so it samples better actions
            state = next_state
            if done:
                break

        # evaluate every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for i in range(TEST):
                state = env.reset()
                for j in range(STEP):
                    # env.render()
                    # renders the environment; comment it out if you are running on a
                    # server and only want the numbers, not the cart-pole animation
                    action = actor.choose_action(state)  # sample an action
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)


if __name__ == '__main__':
    time_start = time.time()
    main()
    time_end = time.time()
    print('Total time is ', time_end - time_start, 's')

'''
episode:  0 Evaluation Average Reward: 17.2
episode:  100 Evaluation Average Reward: 10.6
episode:  200 Evaluation Average Reward: 11.4
episode:  300 Evaluation Average Reward: 10.7
episode:  400 Evaluation Average Reward: 9.3
episode:  500 Evaluation Average Reward: 9.5
episode:  600 Evaluation Average Reward: 9.5
episode:  700 Evaluation Average Reward: 9.6
episode:  800 Evaluation Average Reward: 9.9
episode:  900 Evaluation Average Reward: 8.9
episode:  1000 Evaluation Average Reward: 9.3
episode:  1100 Evaluation Average Reward: 9.8
episode:  1200 Evaluation Average Reward: 9.3
episode:  1300 Evaluation Average Reward: 9.0
episode:  1400 Evaluation Average Reward: 9.4
episode:  1500 Evaluation Average Reward: 9.3
episode:  1600 Evaluation Average Reward: 9.1
episode:  1700 Evaluation Average Reward: 9.0
episode:  1800 Evaluation Average Reward: 9.6
episode:  1900 Evaluation Average Reward: 8.8
episode:  2000 Evaluation Average Reward: 9.4
episode:  2100 Evaluation Average Reward: 9.2
episode:  2200 Evaluation Average Reward: 9.4
episode:  2300 Evaluation Average Reward: 9.2
episode:  2400 Evaluation Average Reward: 9.3
episode:  2500 Evaluation Average Reward: 9.5
episode:  2600 Evaluation Average Reward: 9.6
episode:  2700 Evaluation Average Reward: 9.2
episode:  2800 Evaluation Average Reward: 9.1
episode:  2900 Evaluation Average Reward: 9.6
Total time is  41.6014940738678 s
'''
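A practical note: the script assumes the classic gym API (gym < 0.26), where env.reset() returns only the observation and env.step() returns four values. Under newer gym or gymnasium releases the environment calls need a small adaptation; a sketch of just the changed lines (my assumption, not part of the original post):

# adapted environment calls for gym >= 0.26 / gymnasium (sketch only)
state, _ = env.reset()                                    # reset() returns (observation, info)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated                            # step() returns five values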
