作者 | 李秋键

责编 | Carol

头图 | CSDN 付费下载自视觉中国

人工智能作为当前热门在我们生活中得到了广泛应用,尤其是在智能游戏方面,有的已经达到了可以和职业选手匹敌的效果。而DQN算法作为智能游戏的经典选择算法,其主要是通过奖励惩罚机制来迭代模型,来达到更接近于人类学习的效果。

那在强化学习中, 神经网络是如何被训练的呢? 首先, 我们需要 a1, a2 正确的Q值, 这个 Q 值我们就用之前在 Q learning 中的 Q 现实来代替. 同样我们还需要一个Q估计来实现神经网络的更新. 所以神经网络的的参数就是老的NN参数加学习率 alpha乘以Q现实和Q估计的差距。

我们通过 NN 预测出Q(s2, a1) 和 Q(s2,a2) 的值, 这就是 Q 估计. 然后我们选取 Q 估计中最大值的动作来换取环境中的奖励 reward. 而 Q 现实中也包含从神经网络分析出来的两个 Q 估计值, 不过这个 Q 估计是针对于下一步在 s’ 的估计. 最后再通过刚刚所说的算法更新神经网络中的参数.

DQN是第一个将深度学习模型与强化学习结合在一起从而成功地直接从高维的输入学习控制策略。

  • 创新点:

基于Q-Learning构造Loss Function(不算很新,过往使用线性和非线性函数拟合Q-Table时就是这样做)。

通过experience replay(经验池)解决相关性及非静态分布问题;

使用TargetNet解决稳定性问题。

  • 优点:

算法通用性,可玩不同游戏;

End-to-End 训练方式;

可生产大量样本供监督学习。

  • 缺点:

无法应用于连续动作控制;

只能处理只需短时记忆问题,无法处理需长时记忆问题(后续研究提出了使用LSTM等改进方法);

CNN不一定收敛,需精良调参。

整体的程序效果如下:

实验前的准备

首先我们使用的python版本是3.6.5所用到的库有cv2库用来图像处理;

Numpy库用来矩阵运算;TensorFlow框架用来训练和加载模型。Collection库用于高性能的数据结构。

程序的搭建

1、游戏结构设定:

我们在DQN训练前需要有自己设定好的程序,即在这里为弹珠游戏。在游戏整体框架搭建完成后,对于计算机的决策方式我们需要给他一个初始化的决策算法为了达到更快的训练效果。

程序结构的部分代码如下:

def __init__(self):self.__initGame()# 初始化一些变量self.loseReward = -1self.winReward = 1self.hitReward = 0self.paddleSpeed = 15self.ballSpeed = (7, 7)self.paddle_1_score = 0self.paddle_2_score = 0self.paddle_1_speed = 0.self.paddle_2_speed = 0.self.__reset()'''更新一帧action: [keep, up, down]'''
#     更新ball的位置self.ball_pos = self.ball_pos[0] + self.ballSpeed[0], self.ball_pos[1] + self.ballSpeed[1]# 获取当前场景(只取左半边)image = pygame.surfarray.array3d(pygame.display.get_surface())# image = image[321:, :]pygame.display.update()terminal = Falseif max(self.paddle_1_score, self.paddle_2_score) >= 20:self.paddle_1_score = 0self.paddle_2_score = 0terminal = Truereturn image, reward, terminal
def update_frame(self, action):assert len(action) == 3pygame.event.pump()reward = 0# 绑定一些对象self.score1Render = self.font.render(str(self.paddle_1_score), True, (255, 255, 255))self.score2Render = self.font.render(str(self.paddle_2_score), True, (255, 255, 255))self.screen.blit(self.background, (0, 0))pygame.draw.rect(self.screen, (255, 255, 255), pygame.Rect((5, 5), (630, 470)), 2)pygame.draw.aaline(self.screen, (255, 255, 255), (320, 5), (320, 475))self.screen.blit(self.paddle_1, self.paddle_1_pos)self.screen.blit(self.paddle_2, self.paddle_2_pos)self.screen.blit(self.ball, self.ball_pos)self.screen.blit(self.score1Render, (240, 210))self.screen.blit(self.score2Render, (370, 210))
'''游戏初始化'''def __initGame(self):pygame.init()self.screen = pygame.display.set_mode((640, 480), 0, 32)self.background = pygame.Surface((640, 480)).convert()self.background.fill((0, 0, 0))self.paddle_1 = pygame.Surface((10, 50)).convert()self.paddle_1.fill((0, 255, 255))self.paddle_2 = pygame.Surface((10, 50)).convert()self.paddle_2.fill((255, 255, 0))ball_surface = pygame.Surface((15, 15))pygame.draw.circle(ball_surface, (255, 255, 255), (7, 7), (7))self.ball = ball_surface.convert()self.ball.set_colorkey((0, 0, 0))self.font = pygame.font.SysFont("calibri", 40)'''重置球和球拍的位置'''def __reset(self):self.paddle_1_pos = (10., 215.)self.paddle_2_pos = (620., 215.)self.ball_pos = (312.5, 232.5)

2、行动决策机制:

首先在程序框架中设定不同的行动作为训练对象

# 行动paddle_1(训练对象)
if action[0] == 1:self.paddle_1_speed = 0
elif action[1] == 1:self.paddle_1_speed = -self.paddleSpeed
elif action[2] == 1:self.paddle_1_speed = self.paddleSpeed
self.paddle_1_pos = self.paddle_1_pos[0], max(min(self.paddle_1_speed + self.paddle_1_pos[1], 420), 10)

接着设置一个简单的初始化决策。根据结果判断奖励和惩罚机制,即球撞到拍上奖励,撞到墙上等等惩罚:

其中代码如下:

# 行动paddle_2(设置一个简单的算法使paddle_2的表现较优, 非训练对象)if self.ball_pos[0] >= 305.:if not self.paddle_2_pos[1] == self.ball_pos[1] + 7.5:if self.paddle_2_pos[1] < self.ball_pos[1] + 7.5:self.paddle_2_speed = self.paddleSpeedself.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + self.paddle_2_speed, 420), 10)if self.paddle_2_pos[1] > self.ball_pos[1] - 42.5:self.paddle_2_speed = -self.paddleSpeedself.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + self.paddle_2_speed, 420), 10)else:self.paddle_2_pos = self.paddle_2_pos[0], max(min(self.paddle_2_pos[1] + 7.5, 420), 10)# 行动ball#   球撞拍上if self.ball_pos[0] <= self.paddle_1_pos[0] + 10.:if self.ball_pos[1] + 7.5 >= self.paddle_1_pos[1] and self.ball_pos[1] <= self.paddle_1_pos[1] + 42.5:self.ball_pos = 20., self.ball_pos[1]self.ballSpeed = -self.ballSpeed[0], self.ballSpeed[1]reward = self.hitRewardif self.ball_pos[0] + 15 >= self.paddle_2_pos[0]:if self.ball_pos[1] + 7.5 >= self.paddle_2_pos[1] and self.ball_pos[1] <= self.paddle_2_pos[1] + 42.5:self.ball_pos = 605., self.ball_pos[1]self.ballSpeed = -self.ballSpeed[0], self.ballSpeed[1]#   拍未接到球(另外一个拍得分)if self.ball_pos[0] < 5.:self.paddle_2_score += 1reward = self.loseRewardself.__reset()elif self.ball_pos[0] > 620.:self.paddle_1_score += 1reward = self.winRewardself.__reset()#   球撞墙上if self.ball_pos[1] <= 10.:self.ballSpeed = self.ballSpeed[0], -self.ballSpeed[1]self.ball_pos = self.ball_pos[0], 10elif self.ball_pos[1] >= 455:self.ballSpeed = self.ballSpeed[0], -self.ballSpeed[1]self.ball_pos = self.ball_pos[0], 455

3、DQN算法搭建:

为了方便整体算法的调用,我们首先定义神经网络的函数,包括卷积层损失等函数定义具体如下可见:

'''获得初始化weight权重'''def init_weight_variable(self, shape):return tf.Variable(tf.truncated_normal(shape, stddev=0.01))'''获得初始化bias权重'''def init_bias_variable(self, shape):return tf.Variable(tf.constant(0.01, shape=shape))'''卷积层'''def conv2D(self, x, W, stride):return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")'''池化层'''def maxpool(self, x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')'''计算损失'''def compute_loss(self, q_values, action_now, target_q_values):tmp = tf.reduce_sum(tf.multiply(q_values, action_now), reduction_indices=1)loss = tf.reduce_mean(tf.square(target_q_values - tmp))return loss'''下一帧'''def next_frame(self, action_now, scene_now, gameState):x_now, reward, terminal = gameState.update_frame(action_now)x_now = cv2.cvtColor(cv2.resize(x_now, (80, 80)), cv2.COLOR_BGR2GRAY)_, x_now = cv2.threshold(x_now, 127, 255, cv2.THRESH_BINARY)x_now = np.reshape(x_now, (80, 80, 1))scene_next = np.append(x_now, scene_now[:, :, 0:3], axis=2)return scene_next, reward, terminal'''计算target_q_values'''def compute_target_q_values(self, reward_batch, q_values_batch, minibatch):target_q_values = []for i in range(len(minibatch)):if minibatch[i][4]:target_q_values.append(reward_batch[i])else:target_q_values.append(reward_batch[i] + self.gamma * np.max(q_values_batch[i]))return target_q_values

然后定义整体的类变量DQN,分别定义初始化和训练函数,其中网络层哪里主要就是神经网络层的调用。然后在训练函数里面记录当前动作和数据加载入优化器中达到模型训练效果。

其中代码如下:

def __init__(self, options):self.options = optionsself.num_action = options['num_action']self.lr = options['lr']self.modelDir = options['modelDir']self.init_prob = options['init_prob']self.end_prob = options['end_prob']self.OBSERVE = options['OBSERVE']self.EXPLORE = options['EXPLORE']self.action_interval = options['action_interval']self.REPLAY_MEMORY = options['REPLAY_MEMORY']self.gamma = options['gamma']self.batch_size = options['batch_size']self.save_interval = options['save_interval']self.logfile = options['logfile']self.is_train = options['is_train']'''训练网络'''def train(self, session):x, q_values_ph = self.create_network()action_now_ph = tf.placeholder('float', [None, self.num_action])target_q_values_ph = tf.placeholder('float', [None])# 计算lossloss = self.compute_loss(q_values_ph, action_now_ph, target_q_values_ph)# 优化目标trainStep = tf.train.AdamOptimizer(self.lr).minimize(loss)# 游戏gameState = PongGame()# 用于记录数据dataDeque = deque()# 当前的动作action_now = np.zeros(self.num_action)action_now[0] = 1# 初始化游戏状态x_now, reward, terminal = gameState.update_frame(action_now)x_now = cv2.cvtColor(cv2.resize(x_now, (80, 80)), cv2.COLOR_BGR2GRAY)_, x_now = cv2.threshold(x_now, 127, 255, cv2.THRESH_BINARY)scene_now = np.stack((x_now, )*4, axis=2)# 读取和保存checkpointsaver = tf.train.Saver()session.run(tf.global_variables_initializer())checkpoint = tf.train.get_checkpoint_state(self.modelDir)if checkpoint and checkpoint.model_checkpoint_path:saver.restore(session, checkpoint.model_checkpoint_path)print('[INFO]: Load %s successfully...' % checkpoint.model_checkpoint_path)else:print('[INFO]: No weights found, start to train a new model...')prob = self.init_probnum_frame = 0logF = open(self.logfile, 'a')while True:q_values = q_values_ph.eval(feed_dict={x: [scene_now]})action_idx = get_action_idx(q_values=q_values, prob=prob, num_frame=num_frame, OBSERVE=self.OBSERVE, num_action=self.num_action)action_now = np.zeros(self.num_action)action_now[action_idx] = 1prob = down_prob(prob=prob, num_frame=num_frame, OBSERVE=self.OBSERVE, EXPLORE=self.EXPLORE, init_prob=self.init_prob, end_prob=self.end_prob)for _ in range(self.action_interval):scene_next, reward, terminal = self.next_frame(action_now=action_now, scene_now=scene_now,                                                            gameState=gameState)scene_now = scene_nextdataDeque.append((scene_now, action_now, reward, scene_next, terminal))if len(dataDeque) > self.REPLAY_MEMORY:dataDeque.popleft()loss_now = Noneif (num_frame > self.OBSERVE):minibatch = random.sample(dataDeque, self.batch_size)scene_now_batch = [mb[0] for mb in minibatch]action_batch = [mb[1] for mb in minibatch]reward_batch = [mb[2] for mb in minibatch]scene_next_batch = [mb[3] for mb in minibatch]q_values_batch = q_values_ph.eval(feed_dict={x: scene_next_batch})target_q_values = self.compute_target_q_values(reward_batch, q_values_batch, minibatch)trainStep.run(feed_dict={target_q_values_ph: target_q_values,action_now_ph: action_batch,x: scene_now_batch})loss_now = session.run(loss, feed_dict={target_q_values_ph: target_q_values,action_now_ph: action_batch,x: scene_now_batch})num_frame += 1if num_frame % self.save_interval == 0:name = 'DQN_Pong'saver.save(session, os.path.join(self.modelDir, name), global_step=num_frame)log_content = '<Frame>: %s, <Prob>: %s, <Action>: %s, <Reward>: %s, <Q_max>: %s, <Loss>: %s' % (str(num_frame), str(prob), str(action_idx), str(reward), str(np.max(q_values)), str(loss_now))logF.write(log_content + '\n')print(log_content)logF.close()'''创建网络'''def create_network(self):'''W_conv1 = self.init_weight_variable([9, 9, 4, 16])b_conv1 = self.init_bias_variable([16])W_conv2 = self.init_weight_variable([7, 7, 16, 32])b_conv2 = self.init_bias_variable([32])W_conv3 = self.init_weight_variable([5, 5, 32, 32])b_conv3 = self.init_bias_variable([32])W_conv4 = self.init_weight_variable([5, 5, 32, 64])b_conv4 = self.init_bias_variable([64])W_conv5 = self.init_weight_variable([3, 3, 64, 64])b_conv5 = self.init_bias_variable([64])'''W_conv1 = self.init_weight_variable([8, 8, 4, 32])b_conv1 = self.init_bias_variable([32])W_conv2 = self.init_weight_variable([4, 4, 32, 64])b_conv2 = self.init_bias_variable([64])W_conv3 = self.init_weight_variable([3, 3, 64, 64])b_conv3 = self.init_bias_variable([64])# 5 * 5 * 64 = 1600W_fc1 = self.init_weight_variable([1600, 512])b_fc1 = self.init_bias_variable([512])W_fc2 = self.init_weight_variable([512, self.num_action])b_fc2 = self.init_bias_variable([self.num_action])# input placeholderx = tf.placeholder('float', [None, 80, 80, 4])'''conv1 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(x, W_conv1, 4) + b_conv1, training=self.is_train, momentum=0.9))conv2 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv1, W_conv2, 2) + b_conv2, training=self.is_train, momentum=0.9))conv3 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv2, W_conv3, 2) + b_conv3, training=self.is_train, momentum=0.9))conv4 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv3, W_conv4, 1) + b_conv4, training=self.is_train, momentum=0.9))conv5 = tf.nn.relu(tf.layers.batch_normalization(self.conv2D(conv4, W_conv5, 1) + b_conv5, training=self.is_train, momentum=0.9))flatten = tf.reshape(conv5, [-1, 1600])'''conv1 = tf.nn.relu(self.conv2D(x, W_conv1, 4) + b_conv1)pool1 = self.maxpool(conv1)conv2 = tf.nn.relu(self.conv2D(pool1, W_conv2, 2) + b_conv2)conv3 = tf.nn.relu(self.conv2D(conv2, W_conv3, 1) + b_conv3)flatten = tf.reshape(conv3, [-1, 1600])fc1 = tf.nn.relu(tf.layers.batch_normalization(tf.matmul(flatten, W_fc1) + b_fc1, training=self.is_train, momentum=0.9))fc2 = tf.matmul(fc1, W_fc2) + b_fc2return x, fc2

到这里,我们整体的程序就搭建完成,下面为我们程序的运行结果:

源码地址:

https://pan.baidu.com/s/1ksvjIiQ0BfXOah4PIE1arg

提取码:p74p

作者简介:

李秋键,CSDN博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap竞赛获奖等等。

推荐阅读
  • 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型

  • 京东姚霆:推理能力,正是多模态技术未来亟需突破的瓶颈

  • 性能超越最新序列推荐模型,华为诺亚方舟提出记忆增强的图神经网络

  • FPGA 无解漏洞 “StarBleed”轰动一时,今天来扒一下技术细节!

  • 真惨!连各大编程语言都摆起地摊了

  • 发送0.55 ETH花费近260万美元!这笔神秘交易引发大猜想

你点的每个“在看”,我都认真当成了AI

Python 还能实现哪些 AI 游戏?附上代码一起来一把!相关推荐

  1. ai python 代码提示插件_Python 还能实现哪些 AI 游戏?附上代码一起来一把!

    人工智能作为当前热门在我们生活中得到了广泛应用,尤其是在智能游戏方面,有的已经达到了可以和职业选手匹敌的效果. 而DQN算法作为智能游戏的经典选择算法,其主要是通过奖励惩罚机制来迭代模型,来达到更接近 ...

  2. python游戏脚本实例-Python使用pygame模块编写俄罗斯方块游戏的代码实例

    文章先介绍了关于俄罗斯方块游戏的几个术语. 边框――由10*20个空格组成,方块就落在这里面. 盒子――组成方块的其中小方块,是组成方块的基本单元. 方块――从边框顶掉下的东西,游戏者可以翻转和改变位 ...

  3. 最简单的python语言实现汉诺塔游戏

    最简单的python语言实现汉诺塔游戏 实现代码 def hanoi(n,ch1,ch2,ch3):if n==1:print(ch1, '->', ch3)else:hanoi(n - 1, ...

  4. 我的名片能运行Linux和Python,还能玩2048小游戏,成本只要20元

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 猜猜它是什么?印着姓名.职位和邮箱,看起来是个名片.可是右下角有芯片,看起来又像是个PCB电路板. 其实它是一台超迷你的ARM计算机,不仅 ...

  5. Serpent.AI - 游戏代理框架(Python)

    Serpent.AI - 游戏代理框架(Python) Serpent.AI是一个简单而强大的新颖框架,可帮助开发人员创建游戏代理.将您拥有的任何视频游戏变成一个成熟的实验的沙箱环境,所有这些都是熟悉 ...

  6. 贪吃蛇博弈算法python_算法应用实践:如何用Python写一个贪吃蛇AI

    原标题:算法应用实践:如何用Python写一个贪吃蛇AI 前言 这两天在网上看到一张让人涨姿势的图片,图片中展示的是贪吃蛇游戏, 估计大部分人都玩过.但如果仅仅是贪吃蛇游戏,那么它就没有什么让人涨姿势 ...

  7. python飞机游戏视频教程_10分钟教你用Python做个打飞机小游戏超详细教程

    01 前言 这次还是用python的pygame库来做的游戏.关于这个库的内容,读者可以上网了解一下.本文只讲解用到的知识.代码参考自网上,自己也做了一点代码简化.尽量把最核心的方面用最简单的方式呈现 ...

  8. Python 为何能坐稳 AI 时代头牌语言

    谁会成为AI 和大数据时代的第一开发语言?这本已是一个不需要争论的问题.如果说三年前,Matlab.Scala.R.Java 和 Python还各有机会,局面尚且不清楚,那么三年之后,趋势已经非常明确 ...

  9. 用python自带的tkinter做游戏(一)—— 贪吃蛇 篇

    用python自带的tkinter做游戏(一)-- 贪吃蛇 篇 本人新手,刚自学python满半年,现分享下心得,希望各位老手能指点一二,也希望和我一样的新手能共勉,谢谢~ 大家都知道用python做 ...

最新文章

  1. 【Java集合框架】ArrayList类方法简明解析(举例说明)
  2. EPSON机器人建立TCP/IP通讯的简单demo
  3. .net一个函数要用另一个函数的值_【195期】MySQL中的条件判断函数 CASE WHEN、IF、IFNULL你会用吗?...
  4. SAP CRM产品主数据里的七种ID
  5. FLEX Array和ArrayCollection的区别
  6. 前端 鼠标一次移动半个像素_今天来说说鼠标的DPI该怎么设置
  7. 京东黑科技引爆车联网时代 你的爱车升级了吗?
  8. linux 查看c库版本号,C语言再学习 -- 查看版本及内核信息(转)
  9. mysql上传spc数据慢_SPC实施篇:控制图数据处理这8个细节要注意!
  10. eclipse添加约束文件
  11. PC蓝牙加串口调试助手调试蓝牙设备
  12. Web系统大规模并发-电商秒杀与抢购
  13. SQLZOO 练习题 6 JOIN
  14. 支持向量机(SVM)入门理解与推导
  15. 亚马逊长尾关键词是什么?亚马逊长尾关键词优势
  16. 熟悉的人不认识我了,不熟悉的人认识我了
  17. 阿里云主机免费申请级网站配置
  18. python使用继承开发动物和狗
  19. 读取和博客可视化分析
  20. [ctf.show.reverse] re3

热门文章

  1. L09-10老男孩Linux运维实战培训-Nginx服务生产实战应用指南05(架构解决方案)
  2. pycharm第一个Python程序
  3. Qt 自定义信号与槽
  4. Nagios+pnp4nagios+rrdtool 安装配置为nagios添加自定义插件(三)
  5. 2018/8/24阅读文献 A Unified Model for Multi-Objective Evolutionary Algorithms with Elitism
  6. UOJ #53.线段树区间修改
  7. 学习总结--团队项目
  8. jQuery-1.样式篇---选择器
  9. 石子合并[DP-N3]
  10. junit、hamcrest、eclemma的安装与使用