DQN 强化学习 (Reinforcement Learning)
模块导入和参数设置
这次除了 Torch 自家模块, 我们还要导入 Gym 环境库模块.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym# 超参数
BATCH_SIZE = 32
LR = 0.01 # learning rate
EPSILON = 0.9 # 最优选择动作百分比
GAMMA = 0.9 # 奖励递减参数
TARGET_REPLACE_ITER = 100 # Q 现实网络的更新频率
MEMORY_CAPACITY = 2000 # 记忆库大小
env = gym.make(\'CartPole-v0\') # 立杆子游戏
env = env.unwrapped
N_ACTIONS = env.action_space.n # 杆子能做的动作
N_STATES = env.observation_space.shape[0] # 杆子能获取的环境信息数
神经网络
DQN 当中的神经网络模式, 我们将依据这个模式建立两个神经网络, 一个是现实网络 (Target Net), 一个是估计网络 (Eval Net).
class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.fc1 = nn.Linear(N_STATES, 10)self.fc1.weight.data.normal_(0, 0.1) # initializationself.out = nn.Linear(10, N_ACTIONS)self.out.weight.data.normal_(0, 0.1) # initializationdef forward(self, x):x = self.fc1(x)x = F.relu(x)actions_value = self.out(x)return actions_value
DQN体系
简化的 DQN 体系是这样, 我们有两个 net, 有选动作机制, 有存经历机制, 有学习机制.
class DQN(object):def __init__(self):# 建立 target net 和 eval net 还有 memorydef choose_action(self, x):# 根据环境观测值选择动作的机制return actiondef store_transition(self, s, a, r, s_):# 存储记忆def learn(self):# target 网络更新# 学习记忆库中的记忆
接下来就是具体的啦, 在 DQN 中每个功能都是怎么做的.
class DQN(object):def __init__(self):self.eval_net, self.target_net = Net(), Net()self.learn_step_counter = 0 # 用于 target 更新计时self.memory_counter = 0 # 记忆库记数self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 2)) # 初始化记忆库self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) # torch 的优化器self.loss_func = nn.MSELoss() # 误差公式def choose_action(self, x):x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))# 这里只输入一个 sampleif np.random.uniform() < EPSILON: # 选最优动作actions_value = self.eval_net.forward(x)action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmaxelse: # 选随机动作action = np.random.randint(0, N_ACTIONS)return actiondef store_transition(self, s, a, r, s_):transition = np.hstack((s, [a, r], s_))# 如果记忆库满了, 就覆盖老数据index = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter = 1def learn(self):# target net 参数更新if self.learn_step_counter % TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter = 1# 抽取记忆库中的批数据sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES 1].astype(int)))b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES 1:N_STATES 2]))b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))# 针对做过的动作b_a, 来选 q_eval 的值, (q_eval 原本有所有动作的值)q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1)q_next = self.target_net(b_s_).detach() # q_next 不进行反向传递误差, 所以 detachq_target = b_r GAMMA * q_next.max(1)[0] # shape (batch, 1)loss = self.loss_func(q_eval, q_target)# 计算, 更新 eval netself.optimizer.zero_grad()loss.backward()self.optimizer.step()
训练
按照 Qlearning 的形式进行 off-policy 的更新. 我们进行回合制更行, 一个回合完了, 进入下一回合. 一直到他们将杆子立起来很久.
dqn = DQN() # 定义 DQN 系统for i_episode in range(400):s = env.reset()while True:env.render() # 显示实验动画a = dqn.choose_action(s)# 选动作, 得到环境反馈s_, r, done, info = env.step(a)# 修改 reward, 使 DQN 快速学习x, x_dot, theta, theta_dot = s_r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5r = r1 r2# 存记忆dqn.store_transition(s, a, r, s_)if dqn.memory_counter > MEMORY_CAPACITY:dqn.learn() # 记忆库满了就进行学习if done: # 如果回合结束, 进入下回合breaks = s_
DQN 强化学习 (Reinforcement Learning)相关推荐
- 强化学习 (Reinforcement Learning)
强化学习: 强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益.其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予的奖励或惩罚的刺激下,逐步形成对刺激的预期,产生能 ...
- 强化学习 Reinforcement Learning(三)——是时候用 PARL 框架玩会儿 DOOM 了!!!(下)
强化学习 Reinforcement Learning(三)-- 是时候用 PARL 框架玩会儿 DOOM 了!!!(下) 本文目录 强化学习 Reinforcement Learning(三)-- ...
- 强化学习(Reinforcement Learning)入门学习--01
强化学习(Reinforcement Learning)入门学习–01 定义 Reinforcement learning (RL) is an area of machine learning in ...
- 强化学习(Reinforcement Learning)入门知识
强化学习(Reinforcement Learning) 概率统计知识 1. 随机变量和观测值 抛硬币是一个随机事件,其结果为**随机变量 X ** 正面为1,反面为0,若第 i 次试验中为正面,则观 ...
- 强化学习(Reinforcement Learning)中的Q-Learning、DQN,面试看这篇就够了!
文章目录 1. 什么是强化学习 2. 强化学习模型 2.1 打折的未来奖励 2.2 Q-Learning算法 2.3 Deep Q Learning(DQN) 2.3.1 神经网络的作用 2.3.2 ...
- 强化学习(Reinforcement Learning)
背景 当我们思考学习的本质时,我们首先想到的可能是我们通过与环境的互动来学习.无论是在学习开车还是在交谈,我们都清楚地意识到环境是如何回应我们的行为的,我们试图通过行为来影响后续发生的事情.从互动中学 ...
- 强化学习 (Reinforcement Learning) 基础及论文资料汇总
持续更新中... 书籍 1. <Reinforcement Learning: An Introduction>Richard S. Sutton and Andrew G.Barto , ...
- 李宏毅机器学习——强化学习Reinforcement Learning
目录 应用场景 强化学习的本质 以电脑游戏为例 强化学习三个步骤 第一步:有未知参数的函数 第二步:定义Loss 第三步:Optimization RL的难点 类比GAN Policy Gradien ...
- 强化学习Reinforcement Learning
Abstract Abstract 背景 强化学习算法概念 背景 (1) 强化学习的历史发展 1956年Bellman提出了动态规划方法. 1977年Werbos提出只适应动态规划算法. 1988年s ...
最新文章
- Redis 主从复制
- 居然之家:核心业务系统全面上云,采用PolarDB替代传统商业数据库
- Linux 内核态与用户态通信 netlink
- java string... 参数_Java String.Format() 方法及参数说明
- vue组件 Prop传递数据
- 中兴V880使用手记之二——取得root权限
- excel loc() python_python pandas df.loc[]的典型用法
- hexo添加_hexo 如何给文章添加目录
- token干什么用_什么是TOKEN?Token小号的理解运用,拼多多,知乎,快手,抖音的Token是什么意思...
- 打乱数组 matlab,matlab对数组前N个数求和
- 《彩虹屁》快夸夸我!彩虹屁生成器
- CFA一级学习笔记--衍生品(一)--概念以及定义
- 恢复Windows7快捷方式小箭头的方法
- OPPO技术开放日第六期丨OPPO安全解析“应用与数据安全防护”背后的技术
- 独立版微信动态二维码活码管理系统免授权版
- 【简单3d网络游戏制作】——基于Unity
- 程序员面试,为什么不跟我谈高并发?
- HTML中a标签的作用
- 每日一介绍:烽火算法2.0
- Cytoscape安装及使用