人工智能作为当前热门在我们生活中得到了广泛应用,尤其是在智能游戏方面,有的已经达到了可以和职业选手匹敌的效果。

而DQN算法作为智能游戏的经典选择算法,其主要是通过奖励惩罚机制来迭代模型,来达到更接近于人类学习的效果。

那在强化学习中, 神经网络是如何被训练的呢?

首先, 我们需要 a1, a2 正确的Q值, 这个 Q 值我们就用之前在 Q learning 中的 Q 现实来代替. 同样我们还需要一个Q估计来实现神经网络的更新. 所以神经网络的的参数就是老的NN参数加学习率 alpha乘以Q现实和Q估计的差距。

我们通过 NN 预测出Q(s2, a1) 和 Q(s2,a2) 的值, 这就是 Q 估计.

然后我们选取 Q 估计中最大值的动作来换取环境中的奖励 reward.

而 Q 现实中也包含从神经网络分析出来的两个 Q 估计值, 不过这个 Q 估计是针对于下一步在 s’ 的估计.

最后再通过刚刚所说的算法更新神经网络中的参数.

DQN是第一个将深度学习模型与强化学习结合在一起从而成功地直接从高维的输入学习控制策略。

  • 创新点:

基于Q-Learning构造Loss Function(不算很新,过往使用线性和非线性函数拟合Q-Table时就是这样做)。通过experience replay(经验池)解决相关性及非静态分布问题;使用TargetNet解决稳定性问题。

  • 优点:

算法通用性,可玩不同游戏;End-to-End 训练方式;可生产大量样本供监督学习。

  • 缺点:

无法应用于连续动作控制;只能处理只需短时记忆问题,无法处理需长时记忆问题(后续研究提出了使用LSTM等改进方法);CNN不一定收敛,需精良调参。整体的程序效果如下:

一. 实验前的准备

首先我们使用的python版本是3.6.5所用到的库有cv2库用来图像处理;Numpy库用来矩阵运算;TensorFlow框架用来训练和加载模型。Collection库用于高性能的数据结构。

二. 程序的搭建

1、游戏结构设定:

我们在DQN训练前需要有自己设定好的程序,即在这里为弹珠游戏。

在游戏整体框架搭建完成后,对于计算机的决策方式我们需要给他一个初始化的决策算法为了达到更快的训练效果。

程序结构的部分代码如下:

def __init__(self):self.__initGame()# 初始化一些变量self.loseReward = -1self.winReward = 1self.hitReward = 0self.paddleSpeed = 15self.ballSpeed = (7, 7)self.paddle_1_score = 0self.paddle_2_score = 0self.paddle_1_speed = 0.self.paddle_2_speed = 0.self.__reset()'''更新一帧action: [keep, up, down]'''
#     更新ball的位置self.ball_pos = self.ball_pos[0] + self.ballSpeed[0], self.ball_pos[1] + self.ballSpeed[1]# 获取当前场景(只取左半边)image = pygame.surfarray.array3d(pygame.display.get_surface())# image = image[321:, :]pygame.display.update()terminal = Falseif max(self.paddle_1_score, self.paddle_2_score) >= 20:self.paddle_1_score = 0self.paddle_2_score = 0terminal = Truereturn image, reward, terminal
def update_frame(self, action):assert len(action) == 3pygame.event.pump()reward = 0# 绑定一些对象self.score1Render = self.font.render(str(self.paddle_1_score), True, (255, 255, 255))self.score2Render = self.font.render(str(self.paddle_2_score), True, (255, 255, 255))self.screen.blit(self.background, (0, 0))pygame.draw.rect(self.screen, (255, 255, 255), pygame.Rect((5, 5), (630, 470)), 2)pygame.draw.aaline(self.screen, (255, 255, 255), (320, 5), (320, 475))self.screen.blit(self.paddle_1, self.paddle_1_pos)self.screen.blit(self.paddle_2, self.paddle_2_pos)self.screen.blit(self.ball, self.ball_pos)self.screen.blit(self.score1Render, (240, 210))self.screen.blit(self.score2Render, (370, 210))
'''游戏初始化'''def __initGame(self):pygame.init()self.screen = pygame.display.set_mode((640, 480), 0, 32)self.background = pygame.Surface((640, 480)).convert()self.background.fill((0, 0, 0))self.paddle_1 = pygame.Surface((10, 50)).convert()self.paddle_1.fill((0, 255, 255))self.paddle_2 = pygame.Surface((10, 50)).convert()self.paddle_2.fill((255, 255, 0))ball_surface = pygame.Surface((15, 15))pygame.draw.circle(ball_surface, (255, 255, 255), (7, 7), (7))self.ball = ball_surface.convert()self.ball.set_colorkey((0, 0, 0))self.font = pygame.font.SysFont("calibri", 40)'''重置球和球拍的位置'''def __reset(self):self.paddle_1_pos = (10., 215.)self.paddle_2_pos = (620., 215.)self.ball_pos = (312.5, 232.5)

2、行动决策机制:

首先在程序框架中设定不同的行动作为训练对象

# 行动paddle_1(训练对象)
if action[0] == 1:self.paddle_1_speed = 0
elif action[1] == 1:self.paddle_1_speed = -self.paddleSpeed
elif action[2] == 1:self.paddle_1_speed = self.paddleSpeed
self.paddle_1_pos = self.paddle_1_pos[0], max(min(self.paddle_1_speed + self.paddle_1_pos[1], 420), 10)

接着设置一个简单的初始化决策。根据结果判断奖励和惩罚机制,即球撞到拍上奖励,撞到墙上等等惩罚:其中代码如下:

# 行动paddle_2(设置一个简单的算法使paddle_2的表现较优, 非训练对象)if self.ball_pos[0] >= 305.:if not self.paddle_2_pos[1] == self.ball_pos[1] + 7.5:if self.paddle_2_pos[1] < self.ball_pos[1] + 7.5:self.paddle_2_speed = self.paddleSpeedself.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + self.paddle_2_speed, 420), 10)if self.paddle_2_pos[1] > self.ball_pos[1] - 42.5:self.paddle_2_speed = -self.paddleSpeedself.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + self.paddle_2_speed, 420), 10)else:self.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + 7.5, 420), 10)# 行动ball#   球撞拍上if self.ball_pos[0] <= self.paddle_1_pos[0] + 10.:if self.ball_pos[1] + 7.5 >= self.paddle_1_pos[1] and self.ball_pos[1] <= self.paddle_1_pos[1] + 42.5:self.ball_pos = 20., self.ball_pos[1]self.ballSpeed = -self.ballSpeed[0], self.ballSpeed[1]reward = self.hitRewardif self.ball_pos[0] + 15 >= self.paddle_2_pos[0]:if self.ball_pos[1] + 7.5 >= self.paddle_2_pos[1] and self.ball_pos[1] <= self.paddle_2_pos[1] + 42.5:self.ball_pos = 605., self.ball_pos[1]self.ballSpeed = -self.ballSpeed[0], self.ballSpeed[1]#   拍未接到球(另外一个拍得分)if self.ball_pos[0] < 5.:self.paddle_2_score += 1reward = self.loseRewardself.__reset()elif self.ball_pos[0] > 620.:self.paddle_1_score += 1reward = self.winRewardself.__reset()#   球撞墙上if self.ball_pos[1] <= 10.:self.ballSpeed = self.ballSpeed[0], -self.ballSpeed[1]self.ball_pos = self.ball_pos[0], 10elif self.ball_pos[1] >= 455:self.ballSpeed = self.ballSpeed[0], -self.ballSpeed[1]self.ball_pos = self.ball_pos[0], 455

3、DQN算法搭建:

为了方便整体算法的调用,我们首先定义神经网络的函数,包括卷积层损失等函数定义具体如下可见:

'''获得初始化weight权重'''def init_weight_variable(self, shape):return tf.Variable(tf.truncated_normal(shape, stddev=0.01))'''获得初始化bias权重'''def init_bias_variable(self, shape):return tf.Variable(tf.constant(0.01, shape=shape))'''卷积层'''def conv2D(self, x, W, stride):return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")'''池化层'''def maxpool(self, x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')'''计算损失'''def compute_loss(self, q_values, action_now, target_q_values):tmp = tf.reduce_sum(tf.multiply(q_values, action_now), reduction_indices=1)loss = tf.reduce_mean(tf.square(target_q_values - tmp))return loss'''下一帧'''def next_frame(self, action_now, scene_now, gameState):x_now, reward, terminal = gameState.update_frame(action_now)x_now = cv2.cvtColor(cv2.resize(x_now, (80, 80)), cv2.COLOR_BGR2GRAY)_, x_now = cv2.threshold(x_now, 127, 255, cv2.THRESH_BINARY)x_now = np.reshape(x_now, (80, 80, 1))scene_next = np.append(x_now, scene_now[:, :, 0:3], axis=2)return scene_next, reward, terminal'''计算target_q_values'''def compute_target_q_values(self, reward_batch, q_values_batch, minibatch):target_q_values = []for i in range(len(minibatch)):if minibatch[i][4]:target_q_values.append(reward_batch[i])else:target_q_values.append(reward_batch[i] + self.gamma * np.max(q_values_batch[i]))return target_q_values

然后定义整体的类变量DQN,分别定义初始化和训练函数,其中网络层哪里主要就是神经网络层的调用。

然后在训练函数里面记录当前动作和数据加载入优化器中达到模型训练效果。

其中代码如下:

def __init__(self, options):self.options = optionsself.num_action = options['num_action']self.lr = options['lr']self.modelDir = options['modelDir']self.init_prob = options['init_prob']self.end_prob = options['end_prob']self.OBSERVE = options['OBSERVE']self.EXPLORE = options['EXPLORE']self.action_interval = options['action_interval']self.REPLAY_MEMORY = options['REPLAY_MEMORY']self.gamma = options['gamma']self.batch_size = options['batch_size']self.save_interval = options['save_interval']self.logfile = options['logfile']self.is_train = options['is_train']'''训练网络'''def train(self, session):x, q_values_ph = self.create_network()action_now_ph = tf.placeholder('float', [None, self.num_action])target_q_values_ph = tf.placeholder('float', [None])# 计算lossloss = self.compute_loss(q_values_ph, action_now_ph, target_q_values_ph)# 优化目标trainStep = tf.train.AdamOptimizer(self.lr).minimize(loss)# 游戏gameState = PongGame()# 用于记录数据dataDeque = deque()# 当前的动作action_now = np.zeros(self.num_action)action_now[0] = 1# 初始化游戏状态x_now, reward, terminal = gameState.update_frame(action_now)x_now = cv2.cvtColor(cv2.resize(x_now, (80, 80)), cv2.COLOR_BGR2GRAY)_, x_now = cv2.threshold(x_now, 127, 255, cv2.THRESH_BINARY)scene_now = np.stack((x_now, )*4, axis=2)# 读取和保存checkpointsaver = tf.train.Saver()session.run(tf.global_variables_initializer())checkpoint = tf.train.get_checkpoint_state(self.modelDir)if checkpoint and checkpoint.model_checkpoint_path:saver.restore(session, checkpoint.model_checkpoint_path)print('[INFO]: Load %s successfully...' % checkpoint.model_checkpoint_path)else:print('[INFO]: No weights found, start to train a new model...')prob = self.init_probnum_frame = 0logF = open(self.logfile, 'a')while True:q_values = q_values_ph.eval(feed_dict={x: [scene_now]})action_idx = get_action_idx(q_values=q_values, prob=prob, num_frame=num_frame, OBSERVE=self.OBSERVE, num_action=self.num_action)action_now = np.zeros(self.num_action)action_now[action_idx] = 1prob = down_prob(prob=prob, num_frame=num_frame, OBSERVE=self.OBSERVE, EXPLORE=self.EXPLORE, init_prob=self.init_prob, end_prob=self.end_prob)for _ in range(self.action_interval):scene_next, reward, terminal = self.next_frame(action_now=action_now, scene_now=scene_now,                                                            gameState=gameState)scene_now = scene_nextdataDeque.append((scene_now, action_now, reward, scene_next, terminal))if len(dataDeque) > self.REPLAY_MEMORY:dataDeque.popleft()loss_now = Noneif (num_frame > self.OBSERVE):minibatch = random.sample(dataDeque, self.batch_size)scene_now_batch = [mb[0] for mb in minibatch]action_batch = [mb[1] for mb in minibatch]reward_batch = [mb[2] for mb in minibatch]scene_next_batch = [mb[3] for mb in minibatch]q_values_batch = q_values_ph.eval(feed_dict={x: scene_next_batch})target_q_values = self.compute_target_q_values(reward_batch, q_values_batch, minibatch)trainStep.run(feed_dict={target_q_values_ph: target_q_values,action_now_ph: action_batch,x: scene_now_batch})loss_now = session.run(loss, feed_dict={target_q_values_ph: target_q_values,action_now_ph: action_batch,x: scene_now_batch})num_frame += 1if num_frame % self.save_interval == 0:name = 'DQN_Pong'saver.save(session, os.path.join(self.modelDir, name), global_step=num_frame)log_content = '<Frame>: %s, <Prob>: %s, <Action>: %s, <Reward>: %s, <Q_max>: %s, <Loss>: %s' % (str(num_frame), str(prob), str(action_idx), str(reward), str(np.max(q_values)), str(loss_now))logF.write(log_content + 'n')print(log_content)logF.close()'''创建网络'''def create_network(self):'''W_conv1 = self.init_weight_variable([9, 9, 4, 16])b_conv1 = self.init_bias_variable([16])W_conv2 = self.init_weight_variable([7, 7, 16, 32])b_conv2 = self.init_bias_variable([32])W_conv3 = self.init_weight_variable([5, 5, 32, 32])b_conv3 = self.init_bias_variable([32])W_conv4 = self.init_weight_variable([5, 5, 32, 64])b_conv4 = self.init_bias_variable([64])W_conv5 = self.init_weight_variable([3, 3, 64, 64])b_conv5 = self.init_bias_variable([64])'''W_conv1 = self.init_weight_variable([8, 8, 4, 32])b_conv1 = self.init_bias_variable([32])W_conv2 = self.init_weight_variable([4, 4, 32, 64])b_conv2 = self.init_bias_variable([64])W_conv3 = self.init_weight_variable([3, 3, 64, 64])b_conv3 = self.init_bias_variable([64])# 5 * 5 * 64 = 1600W_fc1 = self.init_weight_variable([1600, 512])b_fc1 = self.init_bias_variable([512])W_fc2 = self.init_weight_variable([512, self.num_action])b_fc2 = self.init_bias_variable([self.num_action])# input placeholderx = tf.placeholder('float', [None, 80, 80, 4])'''conv1 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(x, W_conv1, 4) + b_conv1, training=self.is_train, momentum=0.9))conv2 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv1, W_conv2, 2) + b_conv2, training=self.is_train, momentum=0.9))conv3 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv2, W_conv3, 2) + b_conv3, training=self.is_train, momentum=0.9))conv4 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv3, W_conv4, 1) + b_conv4, training=self.is_train, momentum=0.9))conv5 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv4, W_conv5, 1) + b_conv5, training=self.is_train, momentum=0.9))flatten = tf.reshape(conv5, [-1, 1600])'''conv1 = tf.nn.relu(self.conv2D(x, W_conv1, 4) + b_conv1)pool1 = self.maxpool(conv1)conv2 = tf.nn.relu(self.conv2D(pool1, W_conv2, 2) + b_conv2)conv3 = tf.nn.relu(self.conv2D(conv2, W_conv3, 1) + b_conv3)flatten = tf.reshape(conv3, [-1, 1600])fc1 = tf.nn.relu(tf.layers.batch_normalization(tf.matmul(flatten, W_fc1) + b_fc1, training=self.is_train, momentum=0.9))fc2 = tf.matmul(fc1, W_fc2) + b_fc2return x, fc2

到这里,我们整体的程序就搭建完成,下面为我们程序的运行结果:

源码地址:https://pan.baidu.com/s/1ksvjIiQ0BfXOah4PIE1arg

提取码:p74p

此文转自CSDN 作者 | 李秋键

ai python 代码提示插件_Python 还能实现哪些 AI 游戏?附上代码一起来一把!相关推荐

  1. dw如何写php代码提示,DW CS5 jquery代码提示插件

    喜欢使用Dreamweaver(业内简称dw)做php开发的朋友应该都知道dw是从6.0开始才支持jquery代码提示的.那么对于电脑上安装的是dw cs5而又需要jquery代码提示的该怎么办呢?将 ...

  2. python控制软件点击_Python小程序 控制鼠标循环点击代码实例

    Python小程序 控制鼠标循环点击代码实例 这篇文章主要介绍了Python小程序 控制鼠标循环点击代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以 ...

  3. 谈谈iceCode代码高亮插件的开发问题,由我们国人开发代码高亮插件!

    做为一名资深的开发者,有时候再写一些技术性的文章时,常常需要使用代码高亮插件来展示自己的代码,大家都知道SyntaxHighlighter.Google Code Prettify.Highlight ...

  4. ai二维码插件_送你60款AI脚本插件包,已整合成插件面板的形式,方便在AI中调用...

    送你60款AI脚本插件包,已整合成插件面板的形式,方便在AI中调用. (领取方式见文章末尾) [AI脚本插件合集包] 此AI插件包目前有66款ai脚本插件,已经整合成插件面板的形式,方便在AI中调用. ...

  5. python井字棋_python实现简单井字棋游戏

    井字棋,英文名叫Tic-Tac-Toe,是一种在3*3格子上进行的连珠游戏,和五子棋类似,由于棋盘一般不画边框,格线排成井字故得名.游戏需要的工具仅为纸和笔,然后由分别代表O和X的两个游戏者轮流在格子 ...

  6. visual studio 代码提示插件_请收好:10 个实用的 VS Code 插件

    英文:Daan,翻译:CSDN - Elle 无论你是经验丰富的开发者,还是刚开始工作的新手,你都会想让自己的开发工作尽可能轻松一点.正确的工具使用则可以帮助你实现这个目标. 如果你选用 VS Cod ...

  7. idea代码提示插件_IDEA 插件推荐 —— 让你写出好代码的神器!

    概述 今天介绍的插件主要是围绕编码规范的.有追求的程序员,往往都有代码洁癖,要尽量减少代码的「坏味道」. 代码静态检查是有很多种类,例如圈复杂度.重复率等.业界提供了很多静态检查的插件来识别这些不合规 ...

  8. visual studio 代码提示插件_程序员请收好:10个非常有用的Visual Studio Code插件

    作者 | Daan 译者 | Elle 出品 | CSDN(ID:CSDNnews) [导读]一个插件列表,可以让你的程序员生活变得轻松许多.无论你是经验丰富的开发人员还是刚刚开始第一份工作的初级开发 ...

  9. visual studio 代码提示插件_程序员请收好:10个非常实用的 VS Code 插件

    关注上方"数据挖掘工程师",选择"星标公众号", 关键时间,第一时间送达! 编译:CSDN-Elle,作者:Daan 无论你是经验丰富的开发人员还是刚刚开始第一 ...

最新文章

  1. python快速编程入门课后程序题答案-Python编程从零基础到项目实战 完整PPT+习题答案...
  2. 使用ffmpeg+nginx将rtmp直播流转为hls直播流
  3. python由列表中提取出来的浮点型字符串不能直接转换成整形
  4. mysql pool返回值_Mysql成神之路-InnoDB 的 Buffer Pool
  5. 求高光谱图像相关系数矩阵
  6. c#使用HttpClient调用WebApi
  7. 保存您的lambda,以备不时之需-保存到文件
  8. (JAVA)String类之比较方法
  9. 数据结构课上笔记12
  10. vim命令杂烩(复制粘贴、建文件、撤销等)
  11. java串口助手_java 串口调试助手 源码
  12. mysql 新增字段 添加字段 删除字段 修改字段 级联删除 级联更新 等
  13. shell脚本只运行一个实例
  14. 大数据、Hadoop、Hbase介绍
  15. CSS现状和如何学习
  16. 第四章 《无冬之夜》
  17. 羡慕的核心是焦虑_焦虑是自由的头晕
  18. PS去除图片中文字的方法详细图文教程
  19. uubox.net 网站的第二阶段完成,修复了部分的bug,增加了图片浏览和mp3在线播放等功能...
  20. 小程序封装请求工具http.js

热门文章

  1. redis list放入对象_Redis从入门到入土:详细讲解内存模型以及常用命令
  2. oom机制分析及对应优化策略
  3. zpk在MATLAB中是什么意思,_MATLAB在控制系统中应用 .ppt
  4. Git使用出现git@github.com: Permission denied (publickey)
  5. Java实现数组转字符串及字符串转数组的方法
  6. java 启动jar包JVM参数
  7. ORACLE查询保留字
  8. c程序100例第3题
  9. 开源软件许可协议简介
  10. 挑战Textarea——把textarea中的HTML写入数据库