1 Preface

This post contains Nanxi's notes from studying the PyTorch "REINFORCEMENT LEARNING (DQN) TUTORIAL".

2 Code walkthrough

2.1 Hyperparameters and utilities

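This part sets the hyperparameters, builds the policy and target networks, the optimizer and the replay memory, and defines the ε-greedy action selector select_action() and the plotting helper plot_durations().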

import math
import random

import matplotlib.pyplot as plt
import torch
import torch.optim as optim

# env, device, is_ipython, display, get_screen(), DQN and ReplayMemory
# are defined in earlier sections of the tutorial.

BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

# Get screen size so that we can initialize layers correctly based on shape
# returned from AI gym. Typical dimensions at this point are close to 3x40x90
# which is the result of a clamped and down-scaled render buffer in get_screen()
init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape

# Get number of actions from gym action space
n_actions = env.action_space.n

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
# Replay memory with a capacity of 10000 transitions
memory = ReplayMemory(10000)

steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
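In select_action(), ε starts at EPS_START and decays exponentially toward EPS_END, with EPS_DECAY acting as the time constant in steps. Below is a minimal plain-Python sketch of that schedule, so the numbers can be checked without PyTorch. The helper names eps_threshold and epsilon_greedy are hypothetical, and q_values is assumed to be a plain list standing in for one row of policy_net(state).

```python
import math
import random

# Values copied from the hyperparameter block above.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200


def eps_threshold(steps_done: int) -> float:
    """Exploration probability after `steps_done` steps (same formula as select_action())."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)


def epsilon_greedy(q_values, steps_done, rng=random):
    """Hypothetical helper: greedy action with prob. 1 - eps, uniform random action otherwise.

    q_values is a plain list standing in for policy_net(state)[0] (an assumption,
    so that the sketch runs without PyTorch).
    """
    if rng.random() > eps_threshold(steps_done):
        # Counterpart of .max(1)[1]: index of the largest Q value.
        return max(range(len(q_values)), key=lambda a: q_values[a])
    # Counterpart of random.randrange(n_actions) in select_action().
    return rng.randrange(len(q_values))


# Print how epsilon shrinks as training progresses.
for step in (0, 100, 200, 500, 1000):
    print(step, round(eps_threshold(step), 4))
```

With EPS_DECAY = 200, ε falls from 0.9 to about 0.36 after 200 steps and is within 0.01 of EPS_END = 0.05 by step 1000, so the agent explores heavily only in the first few hundred steps.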

2.2 optimize_model(): the model optimization step

import torch
import torch.nn.functional as F

# memory, Transition, policy_net, target_net, optimizer, device, BATCH_SIZE
# and GAMMA come from the earlier parts of the tutorial.


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
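To follow the data flow of optimize_model() step by step, here is a framework-free toy version in plain Python. It is a sketch under stated assumptions, not the tutorial's code: the networks are replaced by hypothetical functions q_policy / q_target that return a list of Q values per state, Transition is a stand-in namedtuple with the tutorial's field order, and the Huber loss is written out with beta = 1 and mean reduction (the defaults of F.smooth_l1_loss). The backward pass, detach() and gradient clamping have no counterpart here because there is no autograd.

```python
from collections import namedtuple

# Stand-in for the tutorial's Transition namedtuple.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

GAMMA = 0.999


def huber(x: float, beta: float = 1.0) -> float:
    """Smooth L1 / Huber loss on a single residual x."""
    ax = abs(x)
    return 0.5 * x * x / beta if ax < beta else ax - 0.5 * beta


def td_targets_and_loss(transitions, q_policy, q_target, gamma=GAMMA):
    """Hypothetical toy version of the computations inside optimize_model()."""
    # Transpose: list of Transitions -> one Transition holding tuples.
    batch = Transition(*zip(*transitions))

    # Mask of non-final next states (None marks the end of an episode).
    non_final_mask = [s is not None for s in batch.next_state]

    # Q(s_t, a_t): pick the entry of the action actually taken (like .gather(1, ...)).
    state_action_values = [q_policy(s)[a] for s, a in zip(batch.state, batch.action)]

    # V(s_{t+1}) = max_a Q_target(s_{t+1}, a), or 0 for final states.
    next_state_values = [max(q_target(s)) if keep else 0.0
                         for s, keep in zip(batch.next_state, non_final_mask)]

    # Bellman target: r + gamma * V(s_{t+1}).
    expected = [r + gamma * v for r, v in zip(batch.reward, next_state_values)]

    # Mean Huber loss over the batch.
    loss = sum(huber(q - y) for q, y in zip(state_action_values, expected)) / len(expected)
    return expected, loss


# Tiny made-up batch: the second transition ends the episode.
def q_policy(state):
    return [1.0, 2.0] if state == "s0" else [0.5, 0.0]


def q_target(state):
    return [3.0, 1.0]


demo = [Transition("s0", 1, "s1", 1.0), Transition("s1", 0, None, 1.0)]
expected, loss = td_targets_and_loss(demo, q_policy, q_target)
print(expected, loss)
```

For the terminal transition the target is just its reward (1.0), which is exactly what the zero-initialised next_state_values plus non_final_mask achieve in the PyTorch version.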
