PyTorch notes: Actor-Critic (A2C)
For the theory, see: 强化学习笔记:Actor-critic_UQI-LIUWJ的博客-CSDN博客
Since actor-critic combines policy gradient with DQN, much of the code closely mirrors the policy-network and DQN implementations:
pytorch笔记:policy gradient_UQI-LIUWJ的博客-CSDN博客
pytorch 笔记: DQN(experience replay)_UQI-LIUWJ的博客-CSDN博客
1 Imports & hyperparameters
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from torch.distributions import Categorical

GAMMA = 0.95    # reward discount factor
LR = 0.01       # learning rate
EPISODE = 3000  # number of episodes to generate
STEP = 3000     # maximum number of steps per episode
TEST = 10       # number of evaluation episodes run after every 100 training episodes
2 The actor
2.1 Actor network class
class PGNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PGNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_scores = self.fc2(x)
        return F.softmax(action_scores, dim=1)

# PGNetwork maps the state vector at some time step to the probability of each
# action being taken — the same as Policy in the policy gradient notes
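As a quick sanity check of the shapes involved, here is a minimal sketch that redefines the same two-layer network so it runs standalone: a [1, 4] CartPole state produces a [1, 2] row of action probabilities that sums to 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PGNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PGNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

net = PGNetwork(state_dim=4, action_dim=2)
state = torch.randn(1, 4)   # a made-up CartPole state, shape [1, 4]
probs = net(state)
print(probs.shape)          # torch.Size([1, 2])
print(probs.sum().item())   # ~1.0: a valid probability distribution
```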
2.2 Actor class
2.2.1 __init__
class Actor(object):
    def __init__(self, env):
        self.state_dim = env.observation_space.shape[0]
        # dimensionality of a state; 4 for the CartPole problem
        self.action_dim = env.action_space.n
        # size of the action space (how many distinct actions); 2 for CartPole
        self.network = PGNetwork(state_dim=self.state_dim, action_dim=self.action_dim)
        # takes S, outputs the probability of each action being taken
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
2.2.2 Choosing an action
Almost identical to the policy gradient version.
    def choose_action(self, observation):
        # The action is not picked by a Q-value but sampled from the softmax
        # probabilities. In policy gradient and A2C no epsilon-greedy is needed,
        # because the probabilities themselves provide the randomness.
        observation = torch.from_numpy(observation).float().unsqueeze(0)
        # unsqueeze turns the state into a [1, 4] tensor
        probs = self.network(observation)
        # the policy's output: probability of each action in this state
        m = Categorical(probs)  # build a categorical distribution
        action = m.sample()     # sample an action according to those probabilities
        # m.log_prob(action) equals probs.log()[0][action.item()].unsqueeze(0),
        # i.e. the log of the sampled action's probability
        return action.item()    # return a plain Python int
    # So each call to choose_action samples one valid action and returns it.
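To make the comment about m.log_prob concrete, here is a small self-contained sketch with a made-up probability row (the 0.7/0.3 numbers are illustrative, not from a trained network):

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([[0.7, 0.3]])  # hypothetical PGNetwork output, shape [1, 2]
m = Categorical(probs)              # categorical distribution over the 2 actions
action = m.sample()                 # tensor of shape [1] holding 0 or 1
# m.log_prob(action) is the log-probability of the sampled action,
# identical to indexing into probs and taking the log:
expected = probs.log()[0][action.item()].unsqueeze(0)
assert torch.allclose(m.log_prob(action), expected)
```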
2.2.3 Training the actor network
That is, learning to choose better actions.
The td_error weighting term is computed later by the critic (Section 3.2.2); neg_log_prob is obtained via F.cross_entropy and plays the role of -log π(a|s).
    def learn(self, state, action, td_error):
        observation = torch.from_numpy(state).float().unsqueeze(0)
        softmax_input = self.network(observation)
        # probability of each action being taken
        action = torch.LongTensor([action])
        neg_log_prob = F.cross_entropy(input=softmax_input, target=action)
        # note: F.cross_entropy applies log_softmax internally, so it expects raw
        # scores; feeding it softmax outputs, as here, applies softmax twice
        # Gradient ascent: to maximize the policy's value we want to maximize
        # log_prob * td_error, i.e. minimize neg_log_prob * td_error
        loss_a = neg_log_prob * td_error
        self.optimizer.zero_grad()
        loss_a.backward()
        self.optimizer.step()
        # the usual three PyTorch steps
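The role of F.cross_entropy as a negative log-probability can be checked directly: for a single target, cross-entropy equals -log_softmax of the input at that target index (illustrative numbers):

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0]])  # hypothetical network output, shape [1, 2]
action = torch.LongTensor([1])       # the action that was taken
neg_log_prob = F.cross_entropy(scores, action)
# equivalent to picking the taken action's entry out of -log_softmax(scores):
manual = -F.log_softmax(scores, dim=1)[0, action.item()]
assert torch.allclose(neg_log_prob, manual)
```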
3 The critic
Using the actor's sampled transitions, the critic learns V(s) with TD updates.
For simplicity, no target network or experience replay is used here; both are covered in the DQN PyTorch notes.
3.1 Critic network class
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, 1)
        # Unlike before, the output is 1-dimensional rather than action_dim:
        # here we estimate V(s), whereas DQN estimates Q(s, a) for every action

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out
3.2 Critic class
3.2.1 __init__
class Critic(object):
    # learns V(s) from the sampled transitions
    def __init__(self, env):
        self.state_dim = env.observation_space.shape[0]
        # dimensionality of a state; 4 for the CartPole problem
        self.action_dim = env.action_space.n
        # size of the action space; 2 for CartPole
        self.network = QNetwork(state_dim=self.state_dim, action_dim=self.action_dim)
        # takes S, outputs V(S)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
3.2.2 Training the critic network
    def train_Q_network(self, state, reward, next_state):
        # similar to DQN section 5.4, but without the fixed target network
        # or experience replay
        s, s_ = torch.FloatTensor(state), torch.FloatTensor(next_state)
        # the current state and the state reached after taking the action
        v = self.network(s)    # V(s)
        v_ = self.network(s_)  # V(s')
        # TD update: move V(s) towards r + γV(s'); the target is detached so
        # the gradient only flows through V(s)
        loss_q = self.loss_func(v, reward + GAMMA * v_.detach())
        self.optimizer.zero_grad()
        loss_q.backward()
        self.optimizer.step()
        # the usual three PyTorch steps
        with torch.no_grad():
            # no_grad keeps the critic's gradients out of the actor
            # (actor and critic are trained independently)
            td_error = reward + GAMMA * v_ - v
        return td_error
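A worked instance of the TD error with made-up numbers (hypothetical critic outputs, not from a trained network): with r = 1, V(s) = 0.5 and V(s') = 1.0, the error is 1 + 0.95 · 1.0 − 0.5 = 1.45.

```python
import torch

GAMMA = 0.95
v = torch.tensor([0.5])    # hypothetical V(s)
v_ = torch.tensor([1.0])   # hypothetical V(s')
reward = 1.0               # CartPole gives +1 for each surviving step
td_error = reward + GAMMA * v_ - v   # r + γV(s') − V(s)
print(td_error.item())     # 1.45
```

A positive td_error means the transition went better than the critic expected, so the actor's loss pushes the sampled action's probability up.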
4 Main function
def main():
    env = gym.make('CartPole-v1')  # create the CartPole gym environment
    actor = Actor(env)
    critic = Critic(env)
    for episode in range(EPISODE):
        state = env.reset()  # reset the environment for this episode
        for step in range(STEP):
            action = actor.choose_action(state)  # the actor picks an action
            next_state, reward, done, _ = env.step(action)
            # returns state, reward, done (whether to reset the environment), info
            td_error = critic.train_Q_network(state, reward, next_state)
            # first train the critic on the sampled action and states to get a
            # more accurate V(s); gradient = grad[r + γV(s') − V(s)]
            actor.learn(state, action, td_error)
            # then use that V(s) estimate to train the actor to sample better
            # actions; true gradient = grad[log π(a|s) * td_error]
            state = next_state
            if done:
                break
        # evaluate every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for i in range(TEST):
                state = env.reset()
                for j in range(STEP):
                    # env.render()
                    # renders the environment; leave it commented out on a
                    # headless server if you only want the numbers
                    action = actor.choose_action(state)  # sample an action
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)

if __name__ == '__main__':
    time_start = time.time()
    main()
    time_end = time.time()
    print('Total time is ', time_end - time_start, 's')

Sample output:

episode: 0 Evaluation Average Reward: 17.2
episode: 100 Evaluation Average Reward: 10.6
episode: 200 Evaluation Average Reward: 11.4
episode: 300 Evaluation Average Reward: 10.7
episode: 400 Evaluation Average Reward: 9.3
episode: 500 Evaluation Average Reward: 9.5
episode: 600 Evaluation Average Reward: 9.5
episode: 700 Evaluation Average Reward: 9.6
episode: 800 Evaluation Average Reward: 9.9
episode: 900 Evaluation Average Reward: 8.9
episode: 1000 Evaluation Average Reward: 9.3
episode: 1100 Evaluation Average Reward: 9.8
episode: 1200 Evaluation Average Reward: 9.3
episode: 1300 Evaluation Average Reward: 9.0
episode: 1400 Evaluation Average Reward: 9.4
episode: 1500 Evaluation Average Reward: 9.3
episode: 1600 Evaluation Average Reward: 9.1
episode: 1700 Evaluation Average Reward: 9.0
episode: 1800 Evaluation Average Reward: 9.6
episode: 1900 Evaluation Average Reward: 8.8
episode: 2000 Evaluation Average Reward: 9.4
episode: 2100 Evaluation Average Reward: 9.2
episode: 2200 Evaluation Average Reward: 9.4
episode: 2300 Evaluation Average Reward: 9.2
episode: 2400 Evaluation Average Reward: 9.3
episode: 2500 Evaluation Average Reward: 9.5
episode: 2600 Evaluation Average Reward: 9.6
episode: 2700 Evaluation Average Reward: 9.2
episode: 2800 Evaluation Average Reward: 9.1
episode: 2900 Evaluation Average Reward: 9.6
Total time is 41.6014940738678 s