玩转Atari-Pong游戏

该项目基于PaddlePaddle框架完成,详情见玩转Atari-Pong游戏

  • Atari: 雅达利,最初是一家游戏公司,旗下有超过200款游戏,不过已经破产。在强化学习中,Atari游戏是经典的实验环境之一,因此,本项目旨在学习使用强化学习算法玩Atari游戏。
  • Pong: 1972年,雅达利(Atari)创办人布什内尔及达布尼推出首款街机Pong,最初仅生产12部,以简单点线接口仿真打乒乓球的游戏,奠定街机始祖地位。该游戏的简略版英文描述为:

You control the right paddle, you compete against the left paddle controlled by the computer. You each try to keep deflecting the ball away from your goal and into your opponent’s goal.

翻译成中文就是:

你控制右边的球拍,你与电脑控制的左边的球拍竞争。你们各自努力使球不断偏离自己的目标,进入对手的目标。

游戏示意图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IpLUQSTe-1668436245727)(https://ai-studio-static-online.cdn.bcebos.com/cbf668ff565c476db70f1c73b0d1fa7d9dc0269401b64c5b8c53389a4d1393fa)]

从该动态图可以看出,不经训练的右侧球拍完全打不过左侧球拍的,因此我们的目标就是训练右侧球拍使其战胜左侧球拍。

  • Pong环境的状态、动作与奖励:

    • 状态:Pong环境提供的状态默认是Box(210, 160, 3),也就是3通道的彩色图
    • 动作:Pong-v0和Pong-V4版本返回的动作都是Discrete(6),也就是离散的6个动作。网上有介绍:Pong 环境介绍,提到其实6个动作中有用的只有3个,可以参考该介绍,加深理解。
    • 奖励:奖励有三种状态:-1,0,1,分别表示右侧未接到球;中间过程;左侧未接到球。
  • 训练结果展示:

1.Atari环境的安装

在运行man.ipynb之前,请先运行help.ipynb生成我们的依赖环境!!!

目前Ai studio平台并没有内嵌Atari环境,需要我们自行安装,为避免反复安装,我们将安装过程写到help.ipynb。可运行我们的help.ipynb进行持久化安装。主要的安装命令如下所示:

  1. ! pip install atari_py==0.2.6 -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  2. ! pip install ale-py -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  3. ! pip install pyglet -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  4. ! pip install autorom -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  5. ! pip install AutoROM.accept-rom-license -i https://pypi.tuna.tsinghua.edu.cn/simple -t /home/aistudio/external-libraries
  6. !rar x Roms.rar
  7. !python -m atari_py.import_roms ROMS

其中需要注意:第4、5条安装命令可能无法一次成功,多运行几次即可;第6条命令一个项目仅运行一次即可。

2.导入我们的依赖包

注意要先将我们自行安装的Atari环境加入到系统中,即

sys.path.append(‘/home/aistudio/external-libraries’)

import sys
sys.path.append('/home/aistudio/external-libraries')import gym
import numpy as np
import time
import matplotlib.pyplot as plt
import paddle
import os
from collections import deque,Counter
from visualdl import LogWriter
import copy
from collections import Counter
from matplotlib import animation
from PIL import Image

3.环境测试

检测我们是否可以成功加载环境,并查看我们的状态空间和动作空间

env = gym.make('Pong-v4')
print(env.observation_space)
print(env.action_space)
Box(210, 160, 3)
Discrete(6)

4.状态的预处理

在这里我们首先定义了状态的预处理函数preprocess,该函数说明如下:

  • 输入:状态,Pong环境给出的不加任何处理的环境状态,Box(210, 160, 3)
  • 处理:处理过程可以看我们下边的过程图片。
    • 裁剪:将实际没有用的部分去除,主要是Pong环境返回的图像的上边和下边的部分
    • 下采样:在保留特征的前提下进行像素点的缩减
    • 擦除背景,在我们下采样后,环境的背景其实是有两种(109,144),这个也需要多观察才能看出,可以参考我们给出的示例图。
    • 转为灰度图:非0即1,我们仅保留左右球拍和球,减少不必要因素的干扰
    • 打平:将图像打平,进而只使用线性层进行特征学习

4.1 preprocess函数

def preprocess(image):""" 预处理 210x160x3 uint8 frame into 6400 (80x80) 1维 float vector """image = image[35:195]  # 裁剪image = image[::2, ::2, 0]  # 下采样,缩放2倍image[image == 144] = 0  # 擦除背景 (background type 1)image[image == 109] = 0  # 擦除背景 image[image != 0] = 1  # 转为灰度图,除了黑色外其他都是白色return image.astype(np.float).ravel() #打平,(6400,)

4.2 对preprocess函数进行可视化说明,展示中间过程

def show_image(status):status1=status[35:195] #裁剪有效区域status2 = status1[::2, ::2, 0] #下采样,缩减# 观察我们的像素点构成def see_color(status):allcolor=[]for i in range(80):allcolor.extend(status[i])dict_color=Counter(allcolor)print("像素点构成: ",dict_color)see_color(status2)# 观察好像素点后,擦除背景def togray(image_in):image=image_in.copy()image[image == 144] = 0  # 擦除背景 (background type 1)image[image == 109] = 0  # 擦除背景image[image != 0] = 1  # 转为灰度图,除了黑色外其他都是白色return imagestatus3=togray(status2)# 可视化我们的操作中间图def show_status(list_status):fig = plt.figure(figsize=(8, 32), dpi=200)plt.subplots_adjust(left=None, bottom=None, right=None, top=None,wspace=0.3, hspace=0)for i in range(len(list_status)):plt.subplot(1,len(list_status),i+1)plt.imshow(list_status[i],cmap=plt.cm.binary)plt.show()show_status([status,status1,status2,status3])

4.3 背景为109的preprocess展示

status = env.reset() #原始图
show_image(status)
像素点构成:  Counter({109: 6382, 101: 16, 53: 2})/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EavJ8qDW-1668436245730)(main_files/main_13_2.png)]

4.4 背景为144的preprocess展示

for i in range(200):action=env.action_space.sample()status,reward,done,info=env.step(action)show_image(status)
像素点构成:  Counter({144: 6366, 213: 16, 92: 16, 236: 2})

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fYpFJYfP-1668436245731)(main_files/main_15_1.png)]

5.模型的定义,简单的全连接层

class Model(paddle.nn.Layer):""" 使用全连接网络.参数:obs_dim (int): 观测空间的维度.act_dim (int): 动作空间的维度."""def __init__(self, obs_dim, act_dim):super(Model, self).__init__()hid1_size = 256hid2_size = 64self.fc1 = paddle.nn.Linear(obs_dim, hid1_size)self.fc2 = paddle.nn.Linear(hid1_size, hid2_size)self.fc3 = paddle.nn.Linear(hid2_size, act_dim)def forward(self, obs): h1 = paddle.nn.functional.relu(self.fc1(obs))h2 = paddle.nn.functional.relu(self.fc2(h1))prob = paddle.nn.functional.softmax(self.fc3(h2), axis=-1)return prob

6.策略梯度算法

强化学习的经典算法之一,可以参考我们之前的项目【强化学习】REINFORCE算法

在这里我们仅定义预测更新两个函数。

# 梯度下降算法
class PolicyGradient():def __init__(self, model, lr):self.model = modelself.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=self.model.parameters())def predict(self, obs):prob = self.model(obs)return probdef learn(self, obs, action, reward):prob = self.model(obs)#print("prob: ",prob)log_prob = paddle.distribution.Categorical(prob).log_prob(action)loss = paddle.mean(-1 * log_prob * reward)self.optimizer.clear_grad()loss.backward()self.optimizer.step()return loss

7.策略梯度智能体

  • 我们默认从文件中加载参数进行训练,因为PG算法+Pong环境的训练需要大量的时间,一次直接训练完成很耗时;当然我们支持从0开始训练
  • sample: 在训练时调用的函数,带探索
  • predict:在预测(测试)时调用的函数,不带探索
  • learn:更新函数
  • save和load:保存参数和加载参数。注意:这里我们保存了优化器的参数,但是在加载是并未加载上优化器的参数,有报错,未进行修复,但是不加载优化器参数几乎不影响我们的训练的。(这里我其实不太明白到底需不需加载优化器参数,还望大佬不吝赐教,拜谢)
class Agent():def __init__(self, algorithm):self.alg=algorithmif os.path.exists("./savemodel"):print("开始从文件加载参数....")try:self.load()print("从文件加载参数结束....")except:print("从文件加载参数失败,从0开始训练....")def sample(self, obs):""" 根据观测值 obs 采样(带探索)一个动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)#print("prob:",prob)prob = prob.numpy()act = np.random.choice(len(prob), 1, p=prob)[0]  # 根据动作概率选取动作return actdef predict(self, obs):""" 根据观测值 obs 选择最优动作"""obs = paddle.to_tensor(obs, dtype='float32')prob = self.alg.predict(obs)act = prob.argmax().numpy()[0]  # 根据动作概率选择概率最高的动作return actdef learn(self, obs, act, reward):""" 根据训练数据更新一次模型参数"""act = np.expand_dims(act, axis=-1)reward = np.expand_dims(reward, axis=-1)obs = paddle.to_tensor(obs, dtype='float32')act = paddle.to_tensor(act, dtype='int32')reward = paddle.to_tensor(reward, dtype='float32')#print("gggggggggggggg",obs.shape,act.shape,reward.shape)loss = self.alg.learn(obs, act, reward)return loss.numpy()[0]def save(self):paddle.save(self.alg.model.state_dict(),'./savemodel/PG-Pong_net.pdparams')paddle.save(self.alg.optimizer.state_dict(), "./savemodel/opt.pdopt")def load(self):# 加载网络参数model_state_dict=paddle.load('./savemodel/PG-Pong_net.pdparams')self.alg.model.set_state_dict(model_state_dict)# # 加载优化器参数# optimizer_state_dict=paddle.load("./savemodel/opt.pdopt")# self.alg.optimizer.set_state_dict(optimizer_state_dict)

8. 训练与测试

8.1 定义训练函数

# 训练一个episode
def run_train_episode(agent, env):obs_list, action_list, reward_list = [], [], []obs = env.reset()while True:obs = preprocess(obs)  # from shape (210, 160, 3) to (6400,)obs_list.append(obs)action = agent.sample(obs)action_list.append(action)obs, reward, done, info = env.step(action)# if reward!=0:#     print("reward: ",action)reward_list.append(reward)if done:breakreturn obs_list, action_list, reward_list

8.2 定义预测函数

# 评估 agent, 跑 5 个episode,总reward求平均
def run_evaluate_episodes(agent, env, render=False):eval_reward = []for i in range(5):obs = env.reset()episode_reward = 0while True:obs = preprocess(obs)  # from shape (210, 160, 3) to (6400,)action = agent.predict(obs)obs, reward, isOver, _ = env.step(action)episode_reward += rewardif render:env.render()if isOver:breakeval_reward.append(episode_reward)return np.mean(eval_reward)

8.3 定义奖励处理函数

进行奖励衰减操作,衰减因子gamma默认为0.99

def calc_reward_to_go(reward_list, gamma=0.99):"""calculate discounted reward"""reward_arr = np.array(reward_list)for i in range(len(reward_arr) - 2, -1, -1):# G_t = r_t + γ·r_t+1 + ... = r_t + γ·G_t+1reward_arr[i] += gamma * reward_arr[i + 1]# normalize episode rewardsreward_arr -= np.mean(reward_arr)reward_arr /= np.std(reward_arr)return reward_arr

8.4 训练与预测的主函数

便于演示,我们仅进行100次的继续训练,读者可自行增加次数以获得更好的训练效果

def main():env = gym.make('Pong-v4')obs_dim = 80 * 80act_dim = env.action_space.nprint('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))# 根据parl框架构建agentLEARNING_RATE = 5e-4model = Model(obs_dim=obs_dim, act_dim=act_dim)alg = PolicyGradient(model, lr=LEARNING_RATE)agent = Agent(alg)twriter=LogWriter('./logs/PG_Pong')for i in range(100): # default 3000obs_list, action_list, reward_list = run_train_episode(agent, env)twriter.add_scalar('reward',sum(reward_list),i)if i % 50 == 0:print("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))batch_obs = np.array(obs_list)batch_action = np.array(action_list)batch_reward = calc_reward_to_go(reward_list)#print("ggggggggggggg",batch_obs.shape)agent.learn(batch_obs, batch_action, batch_reward)last_test_total_reward=0if (i + 1) % 100 == 0:# render=True 查看显示效果total_reward = run_evaluate_episodes(agent, env, render=False)print('Test reward: {}'.format(total_reward))# save the parametersif last_test_total_reward<total_reward:last_test_total_reward=total_rewardagent.save()# 运行整个程序
main()
obs_dim 6400, act_dim 6W1022 22:01:06.998914   174 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1022 22:01:07.003042   174 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.开始从文件加载参数....
从文件加载参数结束....
Episode 0, Reward Sum 14.0.
Episode 50, Reward Sum 8.0.
Test reward: 12.0

9.使用训练好的网络进行测试并生成动图

9.1 gif动图生成函数

def save_frames_as_gif(frames, filename):#Mess with this to change frame sizeplt.figure(figsize=(frames[0].shape[1]/100, frames[0].shape[0]/100), dpi=300)patch = plt.imshow(frames[0])plt.axis('off')def animate(i):patch.set_data(frames[i])anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)anim.save(filename, writer='pillow', fps=60)

9.2 从文件加载模型参数

model=Model(6400,6)
model_state_dict=paddle.load("./savemodel/PG-Pong_net.pdparams")
model.set_state_dict(model_state_dict)

9.4 使用训练好的模型进行测试并保存过程为动图

env=gym.make('Pong-v4')state=env.reset()
frames = []
done=0
i=0
reward_list=[]
while not done:frames.append(env.render(mode="rgb_array"))obs = preprocess(state)obs = paddle.to_tensor(obs, dtype='float32')prob = model(obs)action = prob.argmax().numpy()[0]next_state,reward,done,_=env.step(action)if reward!=0:reward_list.append(reward)print(i,"   ",reward,done)state=next_statei+=1reward_counter=Counter(reward_list)
print(reward_counter)
print("你的得分为:",reward_counter[1.0],'对手得分为:',reward_counter[-1.0])
if reward_counter[1.0]>reward_counter[-1.0]:print("恭喜您赢了!!!")
else:print("惜败,惜败,训练一下智能体网络再来挑战吧QWQ")save_frames_as_gif(frames, filename="Pong-v4_trained.gif")env.close()
199     1.0 False
732     1.0 False
937     1.0 False
1547     1.0 False
1676     1.0 False
1877     1.0 False
2165     1.0 False
2451     1.0 False
2575     1.0 False
2705     1.0 False
2995     1.0 False
3125     1.0 False
3331     1.0 False
3454     1.0 False
3584     1.0 False
3793     1.0 False
4885     1.0 False
5096     1.0 False
5698     1.0 False
5992     1.0 False
6202     1.0 True
Counter({1.0: 21})
你的得分为: 21 对手得分为: 0
恭喜您赢了!!!

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yCQ3yFtg-1668436245733)(main_files/main_37_1.png)]

10. 总结

本项目参考自飞桨PARL,鼓励大家给点点stars
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-adwdBVQe-1668436245734)(https://ai-studio-static-online.cdn.bcebos.com/92d792700be949219afc12e2d76920190d929c42685e4d29917d3b34fd86fec7)]

本项目目前通过5000+回合的训练,我们的智能体已经学会通过快速抖动法取得游戏的胜利了,但是大概率还不能完全碾压,后续有时间会继续训练或采取更加高效的算法进行改进。然后,这是我的第一个Atari游戏项目,之前都在在经典的控制游戏下进行实验,环境的转变使得学习的难度也上升,训练时间也在增加,学到的东西也在增加,挺好的…还请大佬多多指教,小黑还有很多路要走,嘿嘿!

之前的强化学习项目有:

  • DQN+CartPole-v0
  • A2C+CartPole-v0
  • DDPG+Pendulum-v0
  • TD3+Pendulum-v0
  • REINFORCE+CartPole-v0
  • PPO+CartPole-v0
  • SAC+Pendulum-v0

欢迎大家来交流学习!!!

tionType=1&shared=1)

  • DDPG+Pendulum-v0
  • TD3+Pendulum-v0
  • REINFORCE+CartPole-v0
  • PPO+CartPole-v0
  • SAC+Pendulum-v0

欢迎大家来交流学习!!!

强化学习:玩转Atari-Pong游戏相关推荐

  1. 用深度强化学习玩atari游戏_被追捧为“圣杯”的深度强化学习已走进死胡同

    作者 | 朱仲光 编译 | 夕颜出品 | AI科技大本营(ID:rgznai1100) [导读]近年来,深度强化学习成为一个被业界和学术界追捧的热门技术,社区甚至将它视为金光闪闪的通向 AGI 的圣杯 ...

  2. 深度强化学习加载Atari游戏运行库:Could not find module “XXXX\lib\site-packages\atari_py\ale_interface\ale_c.dll“

    深度强化学习加载Atari游戏运行库:Could not find module "XXXX\lib\site-packages\atari_py\ale_interface\ale_c.d ...

  3. 教程:用强化学习玩转恐龙跳跳

    DeepMind在2013年发表了一篇题为<用深度强化学习玩Atari>的文章,介绍了一种新的用于强化学习的深度学习模型,并展示了它仅使用原始像素作为输入来掌握Atari 2600计算机游 ...

  4. 用深度强化学习玩atari游戏_(一)深度强化学习·入门从游戏开始

    1.在开始正式进入学习之前,有几个概念需要澄清,这样有利于我们对后续的学习有一个大致的框架感 监督型学习与无监督型学习 深度强化学习的范畴 监督型学习是基于已有的带有分类标签的数据集合,来拟合神经网络 ...

  5. 深度学习算法(第37期)----如何用强化学习玩游戏?

    上期我们一起学习了强化学习中的时间差分学习和近似Q学习的相关知识, 今天我们一起用毕生所学来训练一个玩游戏的AI智能体. 由于我们将使用 Atari 环境,我们必须首先安装 OpenAI gym 的 ...

  6. 【强化学习】Playing Atari with Deep Reinforcement Learning (2013)

    Playing Atari with Deep Reinforcement Learning (2013) 这篇文章提出了第一个可以直接用强化学习成功学习控制policies的深度学习模型. 输入是r ...

  7. 用强化学习玩《超级马里奥》

    Pytorch的一个强化的学习教程( Train a Mario-playing RL Agent)使用超级玛丽游戏来学习双Q网络(强化学习的一种类型),官网的文章只有代码, 所以本文将配合官网网站的 ...

  8. CNTK与深度强化学习笔记: Cart Pole游戏示例

    CNTK与深度强化学习笔记之二: Cart Pole游戏示例 前言 前面一篇文章,CNTK与深度强化学习笔记之一: 环境搭建和基本概念,非常概要的介绍了CNTK,深度强化学习和DQN的一些基本概念.这 ...

  9. AlphaStar再升级:多智能体强化学习玩《星际争霸2》,排名超99.8%人类玩家

    [进群了解最新免费公开课.技术沙龙信息] 作者 | DeepMind 译者 | 刘畅 编辑 | Jane 出品 | AI科技大本营(ID:rgznai100) AlphaStar是第一个在没有任何游戏 ...

  10. 用深度强化学习玩超级马里奥兄弟

    介绍 从本文中,你将学习如何使用 Deep Q-Network 和 Double Deep Q-Network(带代码!)玩超级马里奥兄弟. 超级马里奥是任天堂在 1980 年代开发和发行的著名游戏. ...

最新文章

  1. jquery过滤HTML标签方法
  2. mysql5.0操作手册_MySQL 操作手册
  3. 怎样升级android10版本,手机怎么升级win10系统 win10手机版升级教程
  4. ps导出gif颜色不对_PS の手绘《超详细的动态表情包新手绘制指南》
  5. 2013年总结(4)-人脉
  6. this指针常识性问题
  7. win10系统安装sql不上服务器,win10安装sql2000没有反应怎么办_win10安装不了sql2000的解决方法...
  8. macOS 安装postman 中文语言包
  9. Word文档中统一字符串八大妙法(转)
  10. C# 电子印章制作管理系统
  11. 燃尽图 (Burn up and Burn down Chart)—介绍
  12. php 编码转换 乱码解决
  13. my ReadTravel_ Choson / Tailand Racha Island / Phuket Island / Malaysia
  14. APP2SD图文储存卡分区教程
  15. SQL经典案例(学生表,课程表,选课表,教师表) 练习
  16. ABAQUS学习(2):Abaqus求解好后导出点坐标/位移/应变
  17. Gulp折腾记 - (3)常用任务构建的demo[改进版]
  18. 如何用在自己的网页中嵌入腾讯视频网页播放器播放一些文件
  19. 网站推广策略-网站推广120种实用方法_打杂的_新浪博客
  20. 音频信号处理(二)语音信号采集处理与基音周期

热门文章

  1. devcpp的简单使用
  2. 【UVa11584】划分成回文串
  3. 不变初心数 (15 分) C语言
  4. 搭建免费私人服务器---用你的笔记本做服务器
  5. 消极和积极的道德--给亲爱的安德烈
  6. 电压暂降求交流,加Q
  7. 打乱魔方软件_家里魔方吃灰了?这三款魔方App教你轻松上手
  8. 《基于Cortex-M4的虚拟机制作与测试》课程设计 结题报告
  9. 按键精灵手机版_关于截屏一些方法
  10. raid5换硬盘显示ready_[原创]戴尔服务器raid5更换硬盘状态foreign怎么改成ready