Deep Reinforcement Learning -- Actor-Critic
Actor-Critic combines value-based and policy-based methods: the actor is a policy network that maps a state to action probabilities, and the critic is a value network that estimates the value of a state. The critic's TD error serves as the advantage that weights the actor's policy-gradient update.
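Concretely, for a transition (s, a, r, s'), the script below (see its train_model method) uses a one-step TD formulation:

$$\text{target} = r + \gamma V(s'), \qquad A(s, a) = r + \gamma V(s') - V(s)$$

with target = r and A(s, a) = r - V(s) at terminal steps. The critic regresses V(s) toward the target with an MSE loss, while the actor is trained with the advantage A(s, a) weighting log pi(a|s).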
Example code
import sys

import gym
import pylab
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam

EPISODES = 1000


# A2C (Advantage Actor-Critic) agent for CartPole.
# The actor-critic algorithm combines value-based and policy-based methods.
class A2CAgent:
    def __init__(self, state_size, action_size):
        # if you want to see CartPole learning, then change to True
        self.render = True
        self.load_model = False
        # get size of state and action
        self.state_size = state_size
        self.action_size = action_size
        self.value_size = 1

        # these are hyperparameters for the policy gradient
        self.discount_factor = 0.99
        self.actor_lr = 0.001
        self.critic_lr = 0.005

        # create models for the policy (actor) and value (critic) networks
        self.actor = self.build_actor()
        self.critic = self.build_critic()

        if self.load_model:
            self.actor.load_weights("./save_model/cartpole_actor.h5")
            self.critic.load_weights("./save_model/cartpole_critic.h5")

    # approximate policy and value using neural networks
    # actor: state is input, probability of each action is output
    def build_actor(self):  # actor network: state --> action probabilities
        actor = Sequential()
        actor.add(Dense(24, input_dim=self.state_size, activation='relu',
                        kernel_initializer='he_uniform'))
        actor.add(Dense(self.action_size, activation='softmax',
                        kernel_initializer='he_uniform'))
        actor.summary()
        # see note regarding crossentropy in cartpole_reinforce.py
        actor.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=self.actor_lr))
        return actor

    # critic: state is input, value of the state is output
    def build_critic(self):  # critic network: state --> state value V(s)
        critic = Sequential()
        critic.add(Dense(24, input_dim=self.state_size, activation='relu',
                         kernel_initializer='he_uniform'))
        critic.add(Dense(self.value_size, activation='linear',
                         kernel_initializer='he_uniform'))
        critic.summary()
        critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
        return critic

    # using the output of the policy network, pick an action stochastically
    def get_action(self, state):
        policy = self.actor.predict(state, batch_size=1).flatten()  # action probabilities from the actor network
        return np.random.choice(self.action_size, 1, p=policy)[0]

    # update both networks after every step
    def train_model(self, state, action, reward, next_state, done):
        target = np.zeros((1, self.value_size))       # shape (1, 1)
        advantages = np.zeros((1, self.action_size))  # shape (1, 2)

        value = self.critic.predict(state)[0]            # V(s), the critic's estimate for the current state
        next_value = self.critic.predict(next_state)[0]  # V(s'), the critic's estimate for the next state

        # TD target for the critic and advantage (TD error) for the actor
        if done:
            advantages[0][action] = reward - value
            target[0][0] = reward
        else:
            advantages[0][action] = reward + self.discount_factor * next_value - value  # actor target (advantage)
            target[0][0] = reward + self.discount_factor * next_value  # critic target (TD target)

        self.actor.fit(state, advantages, epochs=1, verbose=0)
        self.critic.fit(state, target, epochs=1, verbose=0)


if __name__ == "__main__":
    # in case of CartPole-v1, the maximum episode length is 500
    env = gym.make('CartPole-v1')
    # get size of state and action from the environment
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # make the A2C agent
    agent = A2CAgent(state_size, action_size)
    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            action = agent.get_action(state)
            next_state, reward, done, info = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            # if an action makes the episode end early, give a penalty of -100
            reward = reward if not done or score == 499 else -100

            agent.train_model(state, action, reward, next_state, done)  # train once per action taken

            score += reward
            state = next_state

            if done:
                # every episode, plot the play time
                score = score if score == 500.0 else score + 100
                scores.append(score)
                episodes.append(e)
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_a2c.png")
                print("episode:", e, "  score:", score)

                # if the mean score of the last 10 episodes is bigger than 490,
                # stop training
                if np.mean(scores[-min(10, len(scores)):]) > 490:
                    sys.exit()

        # save the model
        if e % 50 == 0:
            agent.actor.save_weights("./save_model/cartpole_actor.h5")
            agent.critic.save_weights("./save_model/cartpole_critic.h5")
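A note on the actor update: fitting the actor with categorical_crossentropy against a vector that holds the advantage at the chosen action's index is the standard Keras trick for the policy gradient (this is the "note regarding crossentropy" the code comment refers to). A minimal sketch of why it works, with illustrative values that are not part of the original script:

import numpy as np

def actor_loss(policy, advantages):
    # Keras categorical crossentropy: -sum_i y_i * log(p_i).
    # With y holding the advantage A at the chosen action's index (zeros
    # elsewhere), the loss equals -A * log(pi(a|s)), so minimizing it
    # performs gradient ascent on A * log(pi(a|s)): the advantage
    # policy gradient.
    return -np.sum(advantages * np.log(policy))

policy = np.array([0.7, 0.3])      # softmax output of the actor, pi(.|s)
advantages = np.array([0.0, 1.5])  # advantage 1.5 at the taken action (index 1)
print(actor_loss(policy, advantages))  # -1.5 * log(0.3), approx. 1.806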
Reposted from: https://www.cnblogs.com/buyizhiyou/p/10250161.html