模块导入和参数设置

这次除了 Torch 自家模块, 我们还要导入 Gym 环境库模块.

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym# 超参数
BATCH_SIZE = 32
LR = 0.01                   # learning rate
EPSILON = 0.9               # 最优选择动作百分比
GAMMA = 0.9                 # 奖励递减参数
TARGET_REPLACE_ITER = 100   # Q 现实网络的更新频率
MEMORY_CAPACITY = 2000      # 记忆库大小
env = gym.make(\'CartPole-v0\')   # 立杆子游戏
env = env.unwrapped
N_ACTIONS = env.action_space.n  # 杆子能做的动作
N_STATES = env.observation_space.shape[0]   # 杆子能获取的环境信息数

神经网络

DQN 当中的神经网络模式, 我们将依据这个模式建立两个神经网络, 一个是现实网络 (Target Net), 一个是估计网络 (Eval Net).

class Net(nn.Module):def __init__(self, ):super(Net, self).__init__()self.fc1 = nn.Linear(N_STATES, 10)self.fc1.weight.data.normal_(0, 0.1)   # initializationself.out = nn.Linear(10, N_ACTIONS)self.out.weight.data.normal_(0, 0.1)   # initializationdef forward(self, x):x = self.fc1(x)x = F.relu(x)actions_value = self.out(x)return actions_value

DQN体系

简化的 DQN 体系是这样, 我们有两个 net, 有选动作机制, 有存经历机制, 有学习机制.

class DQN(object):def __init__(self):# 建立 target net 和 eval net 还有 memorydef choose_action(self, x):# 根据环境观测值选择动作的机制return actiondef store_transition(self, s, a, r, s_):# 存储记忆def learn(self):# target 网络更新# 学习记忆库中的记忆

接下来就是具体的啦, 在 DQN 中每个功能都是怎么做的.

class DQN(object):def __init__(self):self.eval_net, self.target_net = Net(), Net()self.learn_step_counter = 0     # 用于 target 更新计时self.memory_counter = 0         # 记忆库记数self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2   2))     # 初始化记忆库self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)    # torch 的优化器self.loss_func = nn.MSELoss()   # 误差公式def choose_action(self, x):x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))# 这里只输入一个 sampleif np.random.uniform() < EPSILON:   # 选最优动作actions_value = self.eval_net.forward(x)action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]     # return the argmaxelse:   # 选随机动作action = np.random.randint(0, N_ACTIONS)return actiondef store_transition(self, s, a, r, s_):transition = np.hstack((s, [a, r], s_))# 如果记忆库满了, 就覆盖老数据index = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter  = 1def learn(self):# target net 参数更新if self.learn_step_counter % TARGET_REPLACE_ITER == 0:self.target_net.load_state_dict(self.eval_net.state_dict())self.learn_step_counter  = 1# 抽取记忆库中的批数据sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES 1].astype(int)))b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES 1:N_STATES 2]))b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))# 针对做过的动作b_a, 来选 q_eval 的值, (q_eval 原本有所有动作的值)q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)q_next = self.target_net(b_s_).detach()     # q_next 不进行反向传递误差, 所以 detachq_target = b_r   GAMMA * q_next.max(1)[0]   # shape (batch, 1)loss = self.loss_func(q_eval, q_target)# 计算, 更新 eval netself.optimizer.zero_grad()loss.backward()self.optimizer.step()

训练

按照 Qlearning 的形式进行 off-policy 的更新. 我们进行回合制更行, 一个回合完了, 进入下一回合. 一直到他们将杆子立起来很久.

dqn = DQN() # 定义 DQN 系统for i_episode in range(400):s = env.reset()while True:env.render()    # 显示实验动画a = dqn.choose_action(s)# 选动作, 得到环境反馈s_, r, done, info = env.step(a)# 修改 reward, 使 DQN 快速学习x, x_dot, theta, theta_dot = s_r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5r = r1   r2# 存记忆dqn.store_transition(s, a, r, s_)if dqn.memory_counter > MEMORY_CAPACITY:dqn.learn() # 记忆库满了就进行学习if done:    # 如果回合结束, 进入下回合breaks = s_

DQN 强化学习 (Reinforcement Learning)相关推荐

  1. 强化学习 (Reinforcement Learning)

    强化学习: 强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益.其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予的奖励或惩罚的刺激下,逐步形成对刺激的预期,产生能 ...

  2. 强化学习 Reinforcement Learning(三)——是时候用 PARL 框架玩会儿 DOOM 了!!!(下)

    强化学习 Reinforcement Learning(三)-- 是时候用 PARL 框架玩会儿 DOOM 了!!!(下) 本文目录 强化学习 Reinforcement Learning(三)-- ...

  3. 强化学习(Reinforcement Learning)入门学习--01

    强化学习(Reinforcement Learning)入门学习–01 定义 Reinforcement learning (RL) is an area of machine learning in ...

  4. 强化学习(Reinforcement Learning)入门知识

    强化学习(Reinforcement Learning) 概率统计知识 1. 随机变量和观测值 抛硬币是一个随机事件,其结果为**随机变量 X ** 正面为1,反面为0,若第 i 次试验中为正面,则观 ...

  5. 强化学习(Reinforcement Learning)中的Q-Learning、DQN,面试看这篇就够了!

    文章目录 1. 什么是强化学习 2. 强化学习模型 2.1 打折的未来奖励 2.2 Q-Learning算法 2.3 Deep Q Learning(DQN) 2.3.1 神经网络的作用 2.3.2 ...

  6. 强化学习(Reinforcement Learning)

    背景 当我们思考学习的本质时,我们首先想到的可能是我们通过与环境的互动来学习.无论是在学习开车还是在交谈,我们都清楚地意识到环境是如何回应我们的行为的,我们试图通过行为来影响后续发生的事情.从互动中学 ...

  7. 强化学习 (Reinforcement Learning) 基础及论文资料汇总

    持续更新中... 书籍 1. <Reinforcement Learning: An Introduction>Richard S. Sutton and Andrew G.Barto , ...

  8. ​李宏毅机器学习——强化学习Reinforcement Learning

    目录 应用场景 强化学习的本质 以电脑游戏为例 强化学习三个步骤 第一步:有未知参数的函数 第二步:定义Loss 第三步:Optimization RL的难点 类比GAN Policy Gradien ...

  9. 强化学习Reinforcement Learning

    Abstract Abstract 背景 强化学习算法概念 背景 (1) 强化学习的历史发展 1956年Bellman提出了动态规划方法. 1977年Werbos提出只适应动态规划算法. 1988年s ...

最新文章

  1. Redis 主从复制
  2. 居然之家:核心业务系统全面上云,采用PolarDB替代传统商业数据库
  3. Linux 内核态与用户态通信 netlink
  4. java string... 参数_Java String.Format() 方法及参数说明
  5. vue组件 Prop传递数据
  6. 中兴V880使用手记之二——取得root权限
  7. excel loc() python_python pandas df.loc[]的典型用法
  8. hexo添加_hexo 如何给文章添加目录
  9. token干什么用_什么是TOKEN?Token小号的理解运用,拼多多,知乎,快手,抖音的Token是什么意思...
  10. 打乱数组 matlab,matlab对数组前N个数求和
  11. 《彩虹屁》快夸夸我!彩虹屁生成器
  12. CFA一级学习笔记--衍生品(一)--概念以及定义
  13. 恢复Windows7快捷方式小箭头的方法
  14. OPPO技术开放日第六期丨OPPO安全解析“应用与数据安全防护”背后的技术
  15. 独立版微信动态二维码活码管理系统免授权版
  16. 【简单3d网络游戏制作】——基于Unity
  17. 程序员面试,为什么不跟我谈高并发?
  18. HTML中a标签的作用
  19. 每日一介绍:烽火算法2.0
  20. Cytoscape安装及使用

热门文章

  1. 网线8芯线各自作用是什么?几种常用的网线定义行业标准
  2. 尚鼎峰:抖音短视频是如何在几秒钟内吸引用户观看的?
  3. 速学大学计算机基本内容(一)有图
  4. 努比亚手机安装linux,努比亚红魔5G电竞手机将发布;Linux版荣耀MagicBook降价促销...
  5. php操作redis命令
  6. 多项式定理【OI Pharos 6.2.2】
  7. javascript--贪食蛇(完整版-逻辑思路)
  8. 手机端网页技术--使自己做的asp.net网页适应手机浏览
  9. 图片按日期批量导入WPS表格
  10. 《你不懂我,我不怪你》 作者:余秋雨