1.估值网络简介
在强化学习中,除了上节提到的策略网络(Policy Based)直接选择Action的方法,还有一种学习Action对应的期望值(Expected Utility)的方法,称为Q-Learning,和Plolicy Based方法一样, Q-Learning不依赖环境模型。在有限马尔科夫决策过程中(Markov Decision Process)中,Q-Learning被证明最终可以找到最优的策略。简单来说,将旧的Q-Learning函数,向着学习目标(当前获得的Reward加上下一步可获得的最大期望价值)按一个较小的学习速率学习,得到新的Q-Learning函数,这个就是Q-Learning的具体的思想,学习率决定了覆盖之前掌握信息的比例,通常设为一个比较小的值,如果设定的值比较大,那么覆盖之前的信息比较多,那么会造成整个网络的动荡。

我们用来学习Q-Learning的模型可以是神经网络,这样得到的模型即是估值网络。如果其中的神经网络比较深,那就是DQN。在DQN的使用中会有很多的Trick。第一个是在DQN中引入卷积层,第二个是Experience Replay,第三个Trick就是可以再使用一个DQN网络来辅助训练,第四个Trick,如果再分拆出target DQN的方法上更进一步,那就是Double DQN,第五个Trick是使用dual DQN。

2.GridWorld的任务代码实现

#coding:utf-8
#这里也是导入常用的依赖库
#为了直接能够在终端中运行代码,我还是把魔法命定注释掉了,具体的魔法命令的解释可以看上一个实战import numpy as np
import random
import itertools
import scipy.misc
import matplotlib.pyplot as plt
import tensorflow as tf
import os
# %matplotlib inline #先是创建环境内物体对象的class
class gameOb():def __init__(self, coordinates, size, intensity, channel, reward, name):self.x = coordinates[0]self.y = coordinates[1]self.size = sizeself.intensity = intensityself.channel = channelself.reward = reward self.name = name#创建GridWorld环境的class
class gameEnv():def __init__(self, size):self.sizeX = sizeself.sizeY = sizeself.actions = 4self.objects = []a = self.reset()plt.imshow(a, interpolation = "nearest")#hero是用户控制的对象,4个goal的reward为1, 2个fire的reward为-1def reset(self):self.objects = []hero = gameOb(self.newPosition(), 1, 1, 2, None, 'hero')self.objects.append(hero)goal = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')self.objects.append(goal)hole = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')self.objects.append(hole) goal2 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')self.objects.append(goal2)hole2 = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')self.objects.append(hole2)goal3 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')self.objects.append(goal3)goal4 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')self.objects.append(goal4)state = self.renderEnv()self.state = statereturn state#实现英雄角色移动的方向0,1, 2,3,分别代表下,上, 左,右def moveChar(self, direction):hero = self.objects[0]heroX = hero.xheroY = hero.yif direction == 0 and hero.y >= 1:hero.y -= 1if direction == 1 and hero.y <= self.sizeY-2:hero.y += 1if direction == 2 and hero.x >= 1:hero.x -= 1if direction == 3 and hero.x <= self.sizeX - 2:hero.x += 1self.objects[0] = hero #定义新的位置def newPosition(self):iterables = [range(self.sizeX), range(self.sizeY)]points = []for t in itertools.product(*iterables):points.append(t)currentPositions = []for objectA in self.objects:if (objectA.x, objectA.y) not in currentPositions:currentPositions.append((objectA.x, objectA.y))for pos in currentPositions:points.remove(pos)location = np.random.choice(range(len(points)), replace = False)return points[location]#定义checkGoal函数,用来检查hero是否触碰了goal或者firedef checkGoal(self):others = []for obj in self.objects:if obj.name == 'hero':hero = obj else:others.append(obj)for other in others:if hero.x == other.x and hero.y == other.y:self.objects.remove(other)if other.reward == 1:self.objects.append(gameOb(self.newPosition(), 1, 1, 1, 1, 'goal'))else:self.objects.append(gameOb(self.newPosition(), 1, 1, 0, -1, 'fire'))return other.reward, Falsereturn 0.0, False#渲染图像尺寸def renderEnv(self):a = np.ones([self.sizeY+2, self.sizeX+2, 3])a[1:-1, 1:-1, :] = 0hero = Nonefor item in self.objects:a[item.y+1: item.y + item.size + 1, item.x + 1 : item.x + item.size + 1, item.channel] = item.intensityb = scipy.misc.imresize(a[:, :, 0], [84, 84, 1], interp = 'nearest')c = scipy.misc.imresize(a[:, :, 1], [84, 84, 1], interp = 'nearest')d = scipy.misc.imresize(a[:, :, 2], [84, 84, 1], interp = 'nearest')a = np.stack([b, c, d], axis = 2)return a#定义执行的Action的方法def step(self, action):self.moveChar(action)reward, done = self.checkGoal()state = self.renderEnv()return state, reward, done#设置尺寸为5
env = gameEnv(size = 5)#定义DQN(Deep Q-Network)网络
class Qnetwork():def __init__(self, h_size):self.scalarInput = tf.placeholder(shape = [None, 21168], dtype = tf.float32)self.imageIn = tf.reshape(self.scalarInput, shape = [-1, 84, 84, 3])self.conv1 = tf.contrib.layers.convolution2d(inputs = self.imageIn, num_outputs = 32, kernel_size = [8, 8], stride = [4, 4], padding = 'VALID', biases_initializer = None)self.conv2 = tf.contrib.layers.convolution2d(inputs = self.conv1, num_outputs = 64, kernel_size = [4, 4], stride = [2, 2], padding = 'VALID', biases_initializer = None)self.conv3 = tf.contrib.layers.convolution2d(inputs = self.conv2, num_outputs = 64, kernel_size = [3, 3], stride = [1, 1], padding = 'VALID', biases_initializer = None)self.conv4 = tf.contrib.layers.convolution2d(inputs = self.conv3, num_outputs = 512, kernel_size = [7, 7], stride = [1, 1], padding = 'VALID', biases_initializer = None)self.streamAC, self.streamVC = tf.split(self.conv4, 2, 3)self.streamA = tf.contrib.layers.flatten(self.streamAC)self.streamV = tf.contrib.layers.flatten(self.streamVC)self.AW = tf.Variable(tf.random_normal([h_size // 2, env.actions]))self.VW = tf.Variable(tf.random_normal([h_size // 2, 1]))self.Adavantage = tf.matmul(self.streamA, self.AW)self.Value = tf.matmul(self.streamV, self.VW)self.Qout = self.Value + tf.subtract(self.Adavantage, tf.reduce_mean(self.Adavantage, reduction_indices = 1, keep_dims = True))self.predict = tf.argmax(self.Qout, 1)self.targetQ = tf.placeholder(shape = [None], dtype = tf.float32)self.actions = tf.placeholder(shape = [None], dtype = tf.int32)self.actions_onehot = tf.one_hot(self.actions, env.actions, dtype = tf.float32)self.Q = tf.reduce_sum(tf.multiply(self.Qout, self.actions_onehot), reduction_indices = 1)self.td_error = tf.square(self.targetQ - self.Q)self.loss = tf.reduce_mean(self.td_error)self.trainer = tf.train.AdamOptimizer(learning_rate = 0.0001)self.UpdateModel = self.trainer.minimize(self.loss)#实现Experience Replay策略
class experience_buffer():def __init__(self, buffer_size = 50000):self.buffer = []self.buffer_size = buffer_sizedef add(self, experience):if len(self.buffer) + len(experience) >= self.buffer_size:self.buffer[0: (len(experience) + len(self.buffer)) - self.buffer_size] = []self.buffer.extend(experience)def sample(self, size):return np.reshape(np.array(random.sample(self.buffer, size)), [size, 5])
#把当前state扁平为1维向量的函数
def processState(states):return np.reshape(states, [21168])#更新模型参数
def updateTargetGraph(tfVars, tau):total_vars = len(tfVars)op_holder = []for idx, var in enumerate(tfVars[0: total_vars // 2]):op_holder.append(tfVars[idx + total_vars // 2].assign((var.value() * tau) + ((1 - tau) * tfVars[idx + total_vars // 2].value())))return op_holderdef updateTarget(op_holder,sess):for op in op_holder:sess.run(op)#设置一些训练参数
batch_size = 32
update_freq = 4
y = .99
startE = 1
endE = 0.1
anneling_steps = 10000.
num_episodes = 10000
pre_train_steps = 10000
max_epLength = 50
load_model = False
path = "./dqn"
h_size = 512
tau = 0.001#初始化
mainQN = Qnetwork(h_size)
targetQN = Qnetwork(h_size)
init = tf.global_variables_initializer()trainables = tf.trainable_variables()
targetOps = updateTargetGraph(trainables, tau)myBuffer = experience_buffer()e = startE
stepDrop = (startE - endE) / anneling_stepsrList = []
total_steps = 0saver = tf.train.Saver()
if not os.path.exists(path):os.makedirs(path)#创建默认的session
with tf.Session() as sess:if load_model == True:print('Load Model...')ckpt = tf.train.get_checkpoint_state(path)saver.restore(sess, ckpt.model_checkpoint_path)sess.run(init)updateTarget(targetOps, sess)for i in range(num_episodes + 1):episodeBuffer = experience_buffer()s = env.reset()s = processState(s)d = FalserAll = 0j = 0while j < max_epLength:j += 1if np.random.rand(1) < e or total_steps < pre_train_steps:a = np.random.randint(0, 4)else:a = sess.run(mainQN.predict, feed_dict = {mainQN.scalarInput: [s]})[0]s1, r, d = env.step(a)s1 = processState(s1)total_steps += 1episodeBuffer.add(np.reshape(np.array([s, a, r, s1, d]), [1, 5]))if total_steps > pre_train_steps:if e > endE:e -= stepDropif total_steps % (update_freq) == 0:trainBatch = myBuffer.sample(batch_size)A = sess.run(mainQN.predict, feed_dict = {mainQN.scalarInput: np.vstack(trainBatch[:, 3])})Q = sess.run(targetQN.Qout, feed_dict = {targetQN.scalarInput: np.vstack(trainBatch[:, 3])})doubleQ = Q[range(batch_size), A]targetQ = trainBatch[:, 2] + y * doubleQ_ = sess.run(mainQN.UpdateModel, feed_dict = {mainQN.scalarInput: np.vstack(trainBatch[:, 0]), mainQN.targetQ: targetQ,mainQN.actions:trainBatch[:, 1]})updateTarget(targetOps, sess)rAll += r s = s1if d == True:breakmyBuffer.add(episodeBuffer.buffer)rList.append(rAll)if i > 0 and i % 25 == 0:print('episode', i, ', average reward of last 25 episode', np.mean(rList[-25:]))if i > 0 and i % 1000 == 0:saver.save(sess, path + '/model-' + str(i) + '.cptk')print("Saved Model")saver.save(sess, path + '/model-' + str(i) + '.cptk')rMat = np.resize(np.array(rList), [len(rList) // 100, 100])
rMean = np.average(rMat, 1)
plt.plot(rMean)

这个还是要训练好久,不过还是蛮好玩的,如果可以用强化学习训练一个监督机器人,这样LZ就不会有拖延症啦O(∩_∩)O

TensorFlow实战14:实现估值网络(强化学习二)相关推荐

  1. 《强化学习周刊》第14期:元强化学习的最新研究与应用

    No.14 智源社区 强化学习组 强 化 学  习 研究 观点 资源 活动 关于周刊 强化学习作为人工智能领域研究热点之一,它与元学习相结合的研究进展与成果也引发了众多关注.为帮助研究与工程人员了解该 ...

  2. tensorflow命令行安装失败_2019-1 强化学习入坑记之ancanda安装

    入门RL强化学习,首先要装Tensorflow环境,用ananconda最佳,以此记录我的安装过程 计划: ancanda安装 python3.6 环境设置 tensorflow cpu版本安装 实验 ...

  3. 第二十七课.深度强化学习(二)

    目录 概述 价值学习 Deep Q Network DQN的训练:TD算法(Temporal Difference Learning) 策略学习 Policy Network 策略网络训练:Polic ...

  4. CUDA入门和网络加速学习(二)

    0. 简介 最近作者希望系统性的去学习一下CUDA加速的相关知识,正好看到深蓝学院有这一门课程.所以这里作者以此课程来作为主线来进行记录分享,方便能给CUDA网络加速学习的萌新们去提供一定的帮助. 1 ...

  5. 浅谈强化学习二之马尔卡夫决策过程与动态规划

    书接上文,目前普遍认为强化学习的算法分为基于值函数和基于策略搜索以及其他强化学习算法. 先说强化学习的基础,提及强化学习,就要先认知马尔可夫.确认过眼神,大家都是被公式折磨的人,这里就不讲公式了,只是 ...

  6. 资源 | 《GAN实战:生成对抗网络深度学习》牛津大学Jakub著作(附下载)

    来源:专知 本文共1000字,建议阅读5分钟. 本书囊括了关于GAN的定义.训练.变体等,是关于GAN的最好的书籍之一. [ 导读 ]生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可 ...

  7. .NET网络编程学习(二)

    System.Net.Sockets有很多类,其中最重要的就是Socket类. Socket类 public class Socket : IDisposable Socket 类为网络通信提供了一套 ...

  8. Tensroflow练习,包括强化学习、推荐系统、nlp等

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 代码和数据集  获取: 关注微信公众号 datayx  然后回复  tf  即可获取. AI项 ...

  9. 《机器学习实战:基于Scikit-Learn、Keras和TensorFlow(第2版)》学习笔记

    文章目录 书籍信息 技术和工具 Scikit-Learn TensorFlow Keras Jupyter notebook 资源 书籍配套资料 流行的开放数据存储库 元门户站点(它们会列出开放的数据 ...

最新文章

  1. 使用Java代码连接SAP ABAP Netweaver服务器
  2. 链表排序c++代码_[链表面试算法](一) 链表的删除-相关题型总结(6题)
  3. java音频采样_音频重采样的坑
  4. 基于linux的MsQUIC编译及样例运行
  5. [python教程入门学习]就业寒冬,从拉勾招聘看Python就业前景
  6. vscode markdown_VS Code中的Markdown插件
  7. Double值保留两位小数的四种方法
  8. mysql 限制单个用户资源_限制MySQL数据库单个用户最大连接数等的方法
  9. 网站加载时间测试、网页元素加载性能及网站状态监控工具集合介绍
  10. 离散小波变换wavedec matlab,MATLAB小波变换指令及其功能介绍(超级有用)
  11. 2014世界10大DRAM公司
  12. 恒天然NZMP品牌干酪在2018年国际奶酪大赛中荣获八枚奖牌
  13. TVS管与压敏电阻的性能比较
  14. 3D数学基础——矩阵的介绍与使用
  15. C++之 fgets函数
  16. MatLab函数:pol2cart()
  17. 一条sql执行出现错误Unknown column 'e.sal' in 'on clause'
  18. VS2017项目配置X86改配置x64位
  19. 【C++】Visual Studio教程(十三) -默认键盘快捷方式
  20. 奇门仓储场景具体应用

热门文章

  1. 每一页都是干货,送精选15本Python新书,我必须推荐给你
  2. 矩阵理论| 特殊矩阵:初等矩阵(1) - (行列式、逆矩阵、特征向量)、初等矩阵的相关定理和性质
  3. 【学习笔记】C# 静态类
  4. ClickHouse在工业互联网场景的OLAP平台建设实践
  5. 大学生毕业前必须做的20件事
  6. 出租车计价 (15分)
  7. 山东电销机器人_客服人员,你担心山东百应电销营销机器人系统抢你饭碗吗?...
  8. 计算机考试前的心情作文,考试前的心情作文100字
  9. 那些让人睡不着觉的bug,你有没有遭遇过?
  10. 软件测试工作经验总结