详细分析莫烦DQN代码

Python入门,莫烦是很好的选择,快去b站搜视频吧!
作为一只渣渣白,去看了莫烦的强化学习入门, 现在来回忆总结下DQN,作为笔记记录下来。
主要是对代码做了详细注释
DQN有两个网络,一个eval网络,一个target网络,两个网络结构相同,只是target网络的参数在一段时间后会被eval网络更新。
maze_env.py是环境文件,建立的是一个陷阱游戏的环境,就不用细分析了。
RL_brain.py是建立网络结构的文件:
在类DeepQNetwork中,有五个函数:
n_actions 是动作空间数,环境中上下左右所以是4,n_features是状态特征数,根据位置坐标所以是2.
函数_build_net(self):(讲道理这个注释是详细到不能再详细了)
建立eval网络:

# ------------------ build evaluate_net ------------------
# input 用来接收observation
self.s = tf.placeholder(tf.float32, [None, self.n_features], name='s')
# for calculating loss 用来接收q_target的值
self.q_target = tf.placeholder(tf.float32, [None, self.n_actions], name='Q_target')
# 两层网络l1,l2,神经元 10个,第二层有多少动作输出多少
# variable_scope()用于定义创建变量(层)的操作的上下文管理器
with tf.variable_scope('eval_net'):# c_names(collections_names) are the collections to store variables  在更新target_net参数时会用到# \表示没有[],()的换行c_names, n_l1, w_initializer, b_initializer = \['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES], 10, \tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)# config of layers  nl1第一层有多少个神经元# eval_net 的第一层. collections 是在更新 target_net 参数时会用到with tf.variable_scope('l1'):w1 = tf.get_variable('w1', [self.n_features, n_l1], initializer=w_initializer, collections=c_names)b1 = tf.get_variable('b1', [1, n_l1], initializer=b_initializer, collections=c_names)l1 = tf.nn.relu(tf.matmul(self.s, w1) + b1)print(l1)# eval_net 的第二层. collections 是在更新 target_net 参数时会用到with tf.variable_scope('l2'):w2 = tf.get_variable('w2', [n_l1, self.n_actions], initializer=w_initializer, collections=c_names)b2 = tf.get_variable('b2', [1, self.n_actions], initializer=b_initializer, collections=c_names)self.q_eval = tf.matmul(l1, w2) + b2#作为行为的Q值  估计with tf.variable_scope('loss'): #求误差self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval))
with tf.variable_scope('train'): #梯度下降self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

两层全连接,隐藏层神经元个数都是10个,最后输出是q_eval,再求误差。
target网络建立和上面的大致相同,结构也相同。输出是q_next。

函数:store_transition(): 存储记忆

def store_transition(self, s, a, r, s_):# hasattr() 函数用于判断对象是否包含对应的属性  如果对象有该属性返回 True,否则返回 Falseif not hasattr(self, 'memory_counter'):self.memory_counter = 0# 记录一条 [s, a, r, s_] 记录transition = np.hstack((s, [a, r], s_))# numpy.hstack(tup)参数tup可以是元组,列表,或者numpy数组,返回结果为按顺序堆叠numpy的数组(按列堆叠一个)。# 总 memory 大小是固定的, 如果超出总大小, 旧 memory 就被新 memory 替换index = self.memory_counter % self.memory_sizeself.memory[index, :] = transitionself.memory_counter += 1

存储transition,按照记忆池大小,按行插入,超过的则覆盖存储。

函数choose_action():选择动作

def choose_action(self, observation):# to have batch dimension when feed into tf placeholder  统一 observation 的 shape (1, size_of_observation)observation = observation[np.newaxis, :]#np.newaxis增加维度  []变成[[]]多加了一个行轴,一维变二维if np.random.uniform() < self.epsilon:# forward feed the observation and get q value for every actions# 让 eval_net 神经网络生成所有 action 的值, 并选择值最大的 actionactions_value = self.sess.run(self.q_eval, feed_dict={self.s: observation})action = np.argmax(actions_value)  #返回axis维度的最大值的索引else:action = np.random.randint(0, self.n_actions)return action

如果随机生成的数小于epsilon,则按照q_eval中最大值对应的索引作为action,否则就在动作空间中随机产生动作。

函数learn(): agent学习过程

def learn(self):# 检查是否替换 target_net 参数if self.learn_step_counter % self.replace_target_iter == 0:self.sess.run(self.replace_target_op)  #判断要不要换参数print('\ntarget_params_replaced\n')# sample batch memory from all memory 随机抽取多少个记忆变成batch memoryif self.memory_counter > self.memory_size:sample_index = np.random.choice(self.memory_size, size=self.batch_size)else:sample_index = np.random.choice(self.memory_counter, size=self.batch_size)# 从 memory 中随机抽取 batch_size 这么多记忆batch_memory = self.memory[sample_index, :]  #随机选出的记忆#获取 q_next (target_net 产生了 q) 和 q_eval(eval_net 产生的 q)q_next, q_eval = self.sess.run([self.q_next, self.q_eval],feed_dict={self.s_: batch_memory[:, -self.n_features:],  # fixed paramsself.s: batch_memory[:, :self.n_features],  # newest params})# change q_target w.r.t q_eval's action 先让target = evalq_target = q_eval.copy()batch_index = np.arange(self.batch_size, dtype=np.int32)#返回一个长度为self.batch_size的索引值列表aray([0,1,2,...,31])eval_act_index = batch_memory[:, self.n_features].astype(int)# 返回一个长度为32的动作列表,从记忆库batch_memory中的标记的第2列,self.n_features=2# #即RL.store_transition(observation, action, reward, observation_)中的action# #注意从0开始记,所以eval_act_index得到的是action那一列reward = batch_memory[:, self.n_features + 1]# 返回一个长度为32奖励的列表,提取出记忆库中的rewardq_target[batch_index, eval_act_index] = reward + self.gamma * np.max(q_next, axis=1)"""For example in this batch I have 2 samples and 3 actions:q_eval =[[1, 2, 3],[4, 5, 6]]q_target = q_eval =[[1, 2, 3],[4, 5, 6]]Then change q_target with the real q_target value w.r.t the q_eval's action.For example in:sample 0, I took action 0, and the max q_target value is -1;sample 1, I took action 2, and the max q_target value is -2:q_target =[[-1, 2, 3],[4, 5, -2]]So the (q_target - q_eval) becomes:       q值并不是对位相减[[(-1)-(1), 0, 0],[0, 0, (-2)-(6)]]We then backpropagate this error w.r.t the corresponding action to network,最后我们将这个 (q_target - q_eval) 当成误差, 反向传递会神经网络.所有为 0 的 action 值是当时没有选择的 action, 之前有选择的 action 才有不为0的值.我们只反向传递之前选择的 action 的值,leave other action as error=0 cause we didn't choose it."""# train eval network_, self.cost = self.sess.run([self._train_op, self.loss],feed_dict={self.s: batch_memory[:, :self.n_features],self.q_target: q_target})self.cost_his.append(self.cost)  # 记录 cost 误差# increasing epsilon  逐渐增加 epsilon, 降低行为的随机self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_maxself.learn_step_counter += 1

每200步替换一次两个网络的参数,eval网络的参数实时更新,并用于训练 target网络的用于求loss,每200步将eval的参数赋给target实现更新。

我也不知道这里为什么没有用onehot,所以莫烦在讲求值相减的时候也有点凌乱。其实就是将q_eval赋给q_target,然后按照被选择的动作索引赋q_next的值,即只改变被选择了动作位置处的q值,其他位置q值不变还是q_eval的值,这样为了方便相减,求loss值,反向传递给神经网络。

run_this.py文件,运行:

def run_maze():step = 0   #用来控制什么时候学习for episode in range(100):# 初始化环境observation = env.reset()#print(observation)while True:# 刷新环境env.render()# dqn根据观测值选择动作action = RL.choose_action(observation)# 环境根据行为给出下一个state,reward,是否终止observation_, reward, done = env.step(action)RL.store_transition(observation, action, reward, observation_)#dqn存储记忆#数量大于200以后再训练,每五步学习一次if (step > 200) and (step % 5 == 0):RL.learn()# 将下一个state_变为下次循环的stateobservation = observation_# 如果终止就跳出循环if done:breakstep += 1# end of gameprint('game over')env.destroy()

执行过程就显得比较明了了,调用之前的函数,与环境交互获得observation,选择动作,存储记忆,学习,训练网络。

以上是我对DQN代码的理解,感谢莫烦大佬,本人水平有限,以上内容如有错误之处请批评指正,有相关疑问也欢迎讨论。

详细分析莫烦DQN代码相关推荐

  1. 莫烦Python代码实践(一)——Q-Learning算法工程化解析

    提示:转载请注明出处,若本文无意侵犯到您的合法权益,请及时与作者联系. 莫烦Python代码实践(一)--Q-Learning算法工程化解析 声明 一.Q-Learning算法是什么? 二.Q-Lea ...

  2. 详细分析如何在java代码中使用继承和组合

    文章目录 继承与组合 何时在Java中使用继承 何时在Java中使用组合 继承与组成:两个例子 用Java继承重写方法 Java不具有多重继承 使用super访问父类方法 构造函数与继承一起使用 类型 ...

  3. 详细分析contrex-A9的汇编代码__switch_to(进程切换)

    //函数原型:版本linux-3.0.8 struct task_struct *__switch_to(structtask_struct *, struct thread_info *, stru ...

  4. 生成式对抗网络GAN之实现手写字体的生成(基于keras Tensorflow2.0实现)详细分析训练过程和代码

  5. 莫烦老师,DQN代码学习笔记(图片版)

    详情请见莫烦老师DQN主页:DQN 算法更新 (Tensorflow) - 强化学习 Reinforcement Learning | 莫烦Python 莫烦老师代码(没有我繁琐注释代码直通车):Mo ...

  6. 莫烦python教程部分代码

    GitHub资源整理 莫烦python教程部分代码 莫烦python教程部分代码 整理了一部分莫烦Python教程中的代码,并对代码进行了详细的注释.由于莫烦大佬在做TensorFlow教程时使用的0 ...

  7. Blueprint代码详细分析-Android10.0编译系统(七)

    摘要:Blueprint解析Android.bp到ninja的代码流程时如何走的? 阅读本文大约需要花费18分钟. 文章首发微信公众号:IngresGe 专注于Android系统级源码分析,Andro ...

  8. linux源码acl,Linux自主访问控制机制模块详细分析之posix_acl.c核心代码注释与acl.c文件介绍...

    原标题:Linux自主访问控制机制模块详细分析之posix_acl.c核心代码注释与acl.c文件介绍 2.4.4.6 核心代码注释 1 posix_acl_permission() int(stru ...

  9. 详细分析开源软件 ExifTool 的任意代码执行漏洞 (CVE-2021-22204)

     聚焦源代码安全,网罗国内外最新资讯! 编译:奇安信代码卫士 本文作者详述了自己如何从 ExifTool 发现漏洞的过程. 背景 在查看我最喜欢的漏洞奖励计划时,我发现他们使用ExifTool 从所上 ...

  10. Uboot代码结构详细分析

    1. Bootloader功能分析 Bootloader(如Uboot.Redboot.Blob.vivi等)直接和CPU.外围硬件设备(存储器.网卡.LCD等)打交道,负责初始化硬件设备,以及负责拉 ...

最新文章

  1. 5 分钟入门 Google 最强NLP模型:BERT
  2. 什么是动态DNS 动态DNS有什么用
  3. python 小游戏500行以内_[宜配屋]听图阁
  4. 第十六届全国大学生智能车竞赛线上赛点赛道审核 - 东北赛区(第一批次)
  5. javascript的self和this使用小结
  6. java序列化如何实现_Java实现序列化与反序列化的简单示例
  7. nullnulle-人事管理系统-人事档案-变更管理-人员合同变更
  8. nginx的状态是failed的解决方案
  9. java中会存在内存泄漏吗,请简单描述
  10. [vue] 怎么在watch监听开始之后立即被调用?
  11. Java题-直接赋值与重新创建内存
  12. Unity 自定义Log系统
  13. request[limit]取不到前台的值_基于uFUN开发板的心率计(二)动态阈值算法获取心率值...
  14. 6.6使用环境变量配置外部环境
  15. 一个卡片式的ViewPager,带你玩转ViewPager的PageTransformer属性!
  16. Summernote个性化定制使用帮助(三)
  17. HDU 12O3 I NEED A OFFER!
  18. psp记忆棒测试软件,乱花渐欲迷人眼——PSP用记忆棒选购指南
  19. 大型网站 + 静态页面
  20. 盘点20款让你脑洞大开的AR技术应用

热门文章

  1. 如何在Linux系统下配置JDK环境变量
  2. 锐捷交换机配置保存到计算机,锐捷交换机配置命令总结中篇
  3. 打开计算机不显示百度云管家,百度云管家怎么打不开电脑上的百度云管家打不开的解决方法...
  4. ag-grid 设置行高
  5. 魔兽世界 MPQ(MoPaQ) 文件相关资料
  6. java p2p实例_java文件p2p传输
  7. 检测VC++Redistributable运行库 vcredist_x86.exe
  8. eltable 无数据文案修改_element-table 无数据的时候,把“暂无数据” 改成其他文字或图片...
  9. 彩翼系列-彩票分析软件源代码(双色球,排三,排五,3D,22选5,30选7)源代码
  10. FPGA实现SPI 协议