2048游戏DQN实验

  • 背景
    • 工作
    • 分析问题
      • 状态表征
      • 强化学习算法
      • 参数设计
    • 代码实现
    • 实验结果
      • CNN输入
      • 全连接输入
      • CNN input + Priority
    • 总结

背景

我已经做过一些强化学习相关项目,本科的时候也用min-max搜索做过2048,一直觉得2048应该是适合被强化学习解决的,但是查询发现并没有比较合适靠谱的实现代码,于是完成并开源了我的一部分实现工作,供RL learner 参考,github链接 https://github.com/YangRui2015/2048_env。

工作

  1. 修改了https://github.com/rgal/gym-2048的gym封装的2048环境,增加最大步数和最大非法步数限制,能降低训练难度,增加info输出;
  2. DQN算法实现,训练和测试,模型保存和加载(使用pytorch);
  3. logger日志代码实现,包括控制台、txt文件、tensorboard等数据格式的日志;

DQN实现功能或trick有:

  • CNN input or flattened input
  • randomly fill buffer first
  • soft target replacing
  • linear epsilon decay
  • clip gradient norm
  • Double DQN
  • priority experience replay

分析问题

一个典型的深度强化学习问题,主要有以下几个基本点:

  1. 状态、动作、奖励的设计与表征;
  2. 强化学习算法的实现和参数选择;
  3. 神经网络的设计和调参;

在以上几点完成的基础上,算法的提升主要有三个方面:

  1. 提高训练的稳定性;
  2. 提高训练的速度;
  3. 提高算法的performance;

状态表征

环境的状态输出是4*4的矩阵,针对这种状态我们通常选择flatten成一维向量做全连接输入或者使用CNN输入。CNN虽然更有利于提取一些空间的特征,但是flatten后只要网络拟合能力足够也是能够学习到这些非线性特征的。实验表明两者效果接近,但是flatten最开始阶段学习会比CNN快一些,符合我们的预期。

此外,由于状态矩阵中的值以及奖励值2~1024甚至更大,直接输入网络很容易爆炸,需要对输入状态和奖励值做预处理,这里简单的使用了log(x+1)/16实现归一化。

强化学习算法

使用DQN算法,我也尝试过A2C、PPO,但是训练效果都不好,猜想随机策略在这个问题上表现不如确定性策略。
在实现基本DQN算法上,我还实现了DDQN、Priority DQN(参考莫凡强化学习代码),以及加上了一些提高训练稳定性的方法——soft replacing、epsilon decay、clip gradient norm。

参数设计

DQN参数如下:

batch_size = 128
lr = 1e-4
epsilon = 0.15
memory_capacity = int(1e4)
gamma = 0.99
q_network_iteration = 200
soft_update_theta = 0.1
clip_norm_max = 1
train_interval = 5
conv_size = (32, 64) # num filters
fc_size = (512, 128)

代码实现

DQN需要的网络模块如下,主要区别是CNN输入还是flatten输入(具体见NN_module.py)。

# CNN网络
class CNN_Net(nn.Module):def __init__(self, input_len, output_num, conv_size=(32, 64), fc_size=(1024, 128), out_softmax=False):super(CNN_Net, self).__init__()self.input_len = input_lenself.output_num = output_numself.out_softmax = out_softmax self.conv1 = nn.Sequential(nn.Conv2d(1, conv_size[0], kernel_size=3, stride=1, padding=1),# nn.BatchNorm2d(32),nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(conv_size[0], conv_size[1], kernel_size=3, stride=1, padding=1),# nn.BatchNorm2d(64),nn.ReLU(inplace=True),# nn.MaxPool2d(kernel_size=2, stride=2))self.fc1 = nn.Linear(conv_size[1] * self.input_len * self.input_len, fc_size[0])self.fc2 = nn.Linear(fc_size[0], fc_size[1])self.head = nn.Linear(fc_size[1], self.output_num)def forward(self, x):x = x.reshape(-1,1,self.input_len, self.input_len)x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))output = self.head(x)if self.out_softmax:output = F.softmax(output, dim=1)   #值函数估计不应该有softmaxreturn output# 全连接网络
class FC_Net(nn.Module):def __init__(self, input_num, output_num, fc_size=(1024, 128), out_softmax=False):super(FC_Net, self).__init__()self.input_num = input_numself.output_num = output_numself.out_softmax = out_softmax self.fc1 = nn.Linear(self.input_num, fc_size[0])self.fc2 = nn.Linear(fc_size[0], fc_size[1])self.head = nn.Linear(fc_size[1], self.output_num)def forward(self, x):x = x.reshape(-1, self.input_num)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))output = self.head(x)if self.out_softmax:output = F.softmax(output, dim=1)   #值函数估计不应该有softmaxreturn output

DQN代码实现如下(具体见DQN_agent.py):

class DQN():batch_size = 128lr = 1e-4epsilon = 0.15   memory_capacity =  int(1e4)gamma = 0.99q_network_iteration = 200save_path = "./save/"soft_update_theta = 0.1clip_norm_max = 1train_interval = 5conv_size = (32, 64)   # num filtersfc_size = (512, 128)def __init__(self, num_state, num_action, enable_double=False, enable_priority=True):super(DQN, self).__init__()self.num_state = num_stateself.num_action = num_actionself.state_len = int(np.sqrt(self.num_state))self.enable_double = enable_doubleself.enable_priority = enable_priorityself.eval_net, self.target_net = CNN_Net(self.state_len, num_action,self.conv_size, self.fc_size), CNN_Net(self.state_len, num_action, self.conv_size, self.fc_size)# self.eval_net, self.target_net = FC_Net(self.num_state, self.num_action), FC_Net(self.num_state, self.num_action)self.learn_step_counter = 0self.buffer = Buffer(self.num_state, 'priority', self.memory_capacity)# self.memory = np.zeros((self.memory_capacity, num_state * 2 + 2))     self.initial_epsilon = self.epsilonself.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)def select_action(self, state, random=False, deterministic=False):state = torch.unsqueeze(torch.FloatTensor(state), 0) if not random and np.random.random() > self.epsilon or deterministic:  # greedy policyaction_value = self.eval_net.forward(state)action = torch.max(action_value.reshape(-1,4), 1)[1].data.numpy()else: # random policyaction = np.random.randint(0,self.num_action)return actiondef store_transition(self, state, action, reward, next_state):state = state.reshape(-1)next_state = next_state.reshape(-1)transition = np.hstack((state, [action, reward], next_state))self.buffer.store(transition)# index = self.memory_counter % self.memory_capacity# self.memory[index, :] = transition# self.memory_counter += 1def update(self):#soft update the parametersif self.learn_step_counter % self.q_network_iteration ==0 and self.learn_step_counter:for p_e, p_t in zip(self.eval_net.parameters(), self.target_net.parameters()):p_t.data = self.soft_update_theta * p_e.data + (1 - self.soft_update_theta) * p_t.dataself.learn_step_counter+=1#sample batch from memoryif self.enable_priority:batch_memory, (tree_idx, ISWeights) = self.buffer.sample(self.batch_size)else:batch_memory, _ = self.buffer.sample(self.batch_size)batch_state = torch.FloatTensor(batch_memory[:, :self.num_state])batch_action = torch.LongTensor(batch_memory[:, self.num_state: self.num_state+1].astype(int))batch_reward = torch.FloatTensor(batch_memory[:, self.num_state+1: self.num_state+2])batch_next_state = torch.FloatTensor(batch_memory[:,-self.num_state:])#q_evalq_eval_total = self.eval_net(batch_state)q_eval = q_eval_total.gather(1, batch_action)q_next = self.target_net(batch_next_state).detach()if self.enable_double:q_eval_argmax = q_eval_total.max(1)[1].view(self.batch_size, 1)q_max = q_next.gather(1, q_eval_argmax).view(self.batch_size, 1)else:q_max = q_next.max(1)[0].view(self.batch_size, 1)q_target = batch_reward + self.gamma * q_maxif self.enable_priority:abs_errors = (q_target - q_eval.data).abs()self.buffer.update(tree_idx, abs_errors)# loss = (torch.FloatTensor(ISWeights) * (q_target - q_eval).pow(2)).mean()   loss = (q_target - q_eval).pow(2).mean() # 可能去掉ISweight更好??# print(ISWeights)# print(loss)# import pdb; pdb.set_trace()else:loss = F.mse_loss(q_eval, q_target)self.optimizer.zero_grad()loss.backward()nn.utils.clip_grad_norm_(self.eval_net.parameters(), self.clip_norm_max)self.optimizer.step()return lossdef save(self, path=None, name='dqn_net.pkl'):path = self.save_path if not path else pathutils.check_path_exist(path)torch.save(self.eval_net.state_dict(), path + name)def load(self, path=None, name='dqn_net.pkl'):path = self.save_path if not path else pathself.eval_net.load_state_dict(torch.load(path + name))def epsilon_decay(self, episode, total_episode):self.epsilon = self.initial_epsilon * (1 - episode / total_episode)

其中buffer类实现了普通版和priority版,具体见Buffer_module.py,主函数实现main_dqn.py

实验结果

CNN输入

全连接输入

CNN input + Priority


总结

实验最好的结果已经到了平均得分6000分,继续训练还能继续增长,但是太花费时间和计算资源,于是我没有继续实验了。整个过程对解决实际强化学习问题能有更好的认识,并且实现通过实验能更好的理解一些方法或trick提出的原因。

希望能对大家有帮助。

2048游戏DQN实验相关推荐

  1. 2048游戏c语言实验报告,2048游戏语言实验报告.doc

    2048游戏语言实验报告 成绩评定 教师签名 评定日期 嘉应学院 计算机学院 实验报告 课程名称: C程序设计 开课学期: 2015-2016学年第1学期 班 级: 计算机1505 指导老师: 陈广明 ...

  2. Python 《Python 实现 2048 游戏》实验报告

    74340da14d79fae0a21de03d44699f80b6c624f3.jpg 2048 游戏 wiki:<2048>是一款单人在线和移动端游戏,由19岁的意大利人 Gabrie ...

  3. python课堂实验_用Python做2048游戏 网易云课堂配套实验课。通过GUI来体验编程的乐趣。...

    标签: 第1节 认识wxpython 第2节 画几个形状 第3节 再做个计算器 第4节 最后实现个2048游戏 实验1-认识wxpython 一.实验说明 1. 环境登录 无需密码自动登录,系统用户名 ...

  4. 【STM32单片机】2048游戏设计

    文章目录 一.简介 二.硬件资源 1.硬件准备 2.硬件连接 四.软件设计 1.软件结构 2.主要代码 五.实验现象 一.简介 本项目支持STM32F103/STM32F407控制器,使用TFTLCD ...

  5. 2048游戏分析、讨论与扩展 - Part I - 游戏分析与讨论

    2048这个游戏从刚出开始就风靡整个世界.本技术博客的目的是想对2048涉及到相关的所有问题进行细致的分析与讨论,得到一些大家能够接受并且理解的结果.在这基础上,扩展2048的游戏性,使其变得更好玩, ...

  6. 带你用Python制作超级经典的2048游戏(文末赠书)

    名字:阿玥的小东东 学习:Python.C/C++ 主页链接:阿玥的小东东的博客_CSDN博客-python&&c++高级知识,过年必备,C/C++知识讲解领域博主 目录 2048游戏 ...

  7. 是男人就下100层【第五层】——2048游戏从源代码到公布市场

    上一篇<是男人就下100层[第五层]--换肤版2048游戏>中阳光小强对2048游戏用自己的方式进行了实现,并分享了核心源码,这一篇阳光小强打算将该项目的全部源码公开并结合这个实例在这篇文 ...

  8. Cocos2d-xna : 横版战略游戏开发实验5 TiledMap实现关卡地图

    Cocos2d-xna : 横版战略游戏开发实验5 TiledMap实现关卡地图 在前面的几篇中动手实验使用了CCSprite.CCScene.CCLayer.CCAction.CCMenu等coco ...

  9. 2048游戏-AI程序算法分析

    针对目前火爆的2048游戏,有人实现了一个AI程序,可以以较大概率(高于90%)赢得游戏,并且作者在stackoverflow上简要介绍了AI的算法框架和实现思路.但是这个回答主要集中在启发函数的选取 ...

  10. 2048游戏的python实现

    2019独角兽企业重金招聘Python工程师标准>>> 一个2048小游戏的python实现 今天看了OSC网友xiaohui_hubei的2048游戏代码感觉很有意思,特意花时间玩 ...

最新文章

  1. 个人理解卷积 池化 的用处
  2. 理解并演示:思科的netflow功能(200-120新增考点)
  3. MFC显示JPG、JIF图片
  4. Qt下的OpenGL 编程(1)Qt下的OpenGL编程必须步骤
  5. 手把手教会你(单/多)文件上传(并修改文件默认的最大最小值)
  6. 所谓语音合成 是计算机根据语言学,计算语言学完整1
  7. 计算机专业能不能转音乐系,中国音乐学院可以转专业吗,中国音乐学院新生转专业政策...
  8. numpy中的方差、协方差、相关系数
  9. TWebBrowser 与 MSHTML(3): window 对象的属性、方法、事件纵览
  10. HDR色调映射(一):基础概念
  11. python爬虫——爬起点中文网小说
  12. 使用python对目录下的文件进行分类
  13. 关于同步、异步传输的解释
  14. ONF与天地互连共同成立开放SDN推广中心(OSPC)
  15. 植物大战僵尸二:游戏界面的绘制
  16. 企业电子邮件系统全局地址簿管理及使用方法介绍
  17. thinkphp5 验证码跨域/验证失败 问题解决方案
  18. 云服务器的防火墙有什么作用?
  19. 屏山计算机学校,四川省屏山县职业技术学校怎么样、好不好
  20. 听迅雷COO程浩先生演讲有感

热门文章

  1. OTL/OCL/BTL/甲类/乙类/甲乙类
  2. 未知错误,可能由于拨号连接未创建成功
  3. 【WLAN从入门到精通-基础篇】第1期——WLAN定义和基本架构
  4. 无线网络连接 wlan test
  5. linux文件系统 ubi,UBI文件系统简介
  6. 支付宝登录java和android
  7. 活出富有成效和充实的十年:让新的一年有个好开始的三条秘诀
  8. luogu P5294 [HNOI2019]序列
  9. 华人工程师在美国-从微软高管离职说起
  10. ATTCK随笔系列之二:偷天陷阱