蒙特卡洛树搜索(MCTS)的实例代码
另一篇博客对代码的讲解
原理:
在当前树节点(设为A)状态下,如果所有子节点都展开了,则按UCT算法选择最优节点作为当前节点,循环下去,直到该节点有未展开的子节点,则从未展开的子节点里瞎选一个并展开它(设为B),从B开始进行模拟(走下去直到游戏结束),得到该此模拟的Reward, 从B开始往上回溯(一直到A),沿途累加上该次Reward和模拟次数1;
以上步骤,在A状态下可重复很多次直到超时为止,此时从A的所有子节点里选择胜率最高的一个(UCT公式的左半边),走一步(并把该子节点设为A);
以上循环执行就能一步步走下去;
#!/usr/bin/env python
# -*- coding: utf-8 -*-import sys
import math
import random
import numpy as npAVAILABLE_CHOICES = [1, -1, 2, -2]
AVAILABLE_CHOICE_NUMBER = len(AVAILABLE_CHOICES)
MAX_ROUND_NUMBER = 10class State(object):"""蒙特卡罗树搜索的游戏状态,记录在某一个Node节点下的状态数据,包含当前的游戏得分、当前的游戏round数、从开始到当前的执行记录。需要实现判断当前状态是否达到游戏结束状态,支持从Action集合中随机取出操作。"""def __init__(self):self.current_value = 0.0# For the first root node, the index is 0 and the game should start from 1self.current_round_index = 0self.cumulative_choices = []def get_current_value(self):return self.current_valuedef set_current_value(self, value):self.current_value = valuedef get_current_round_index(self):return self.current_round_indexdef set_current_round_index(self, turn):self.current_round_index = turndef get_cumulative_choices(self):return self.cumulative_choicesdef set_cumulative_choices(self, choices):self.cumulative_choices = choicesdef is_terminal(self):# The round index starts from 1 to max round numberreturn self.current_round_index == MAX_ROUND_NUMBERdef compute_reward(self):return -abs(1 - self.current_value)def get_next_state_with_random_choice(self):random_choice = random.choice([choice for choice in AVAILABLE_CHOICES])next_state = State()next_state.set_current_value(self.current_value + random_choice)next_state.set_current_round_index(self.current_round_index + 1)next_state.set_cumulative_choices(self.cumulative_choices +[random_choice])return next_statedef __repr__(self):return "State: {}, value: {}, round: {}, choices: {}".format(hash(self), self.current_value, self.current_round_index,self.cumulative_choices)class Node(object):"""蒙特卡罗树搜索的树结构的Node,包含了父节点和直接点等信息,还有用于计算UCB的遍历次数和quality值,还有游戏选择这个Node的State。"""def __init__(self):self.parent = Noneself.children = []self.visit_times = 0self.quality_value = 0.0self.state = Nonedef set_state(self, state):self.state = statedef get_state(self):return self.statedef get_parent(self):return self.parentdef set_parent(self, parent):self.parent = parentdef get_children(self):return self.childrendef get_visit_times(self):return self.visit_timesdef set_visit_times(self, times):self.visit_times = timesdef visit_times_add_one(self):self.visit_times += 1def get_quality_value(self):return self.quality_valuedef set_quality_value(self, value):self.quality_value = valuedef quality_value_add_n(self, n):self.quality_value += ndef is_all_expand(self):return len(self.children) == AVAILABLE_CHOICE_NUMBERdef add_child(self, sub_node):sub_node.set_parent(self)self.children.append(sub_node)def __repr__(self):return "Node: {}, Q/N: {}/{}, state: {}".format(hash(self), self.quality_value, self.visit_times, self.state)def tree_policy(node):"""蒙特卡罗树搜索的Selection和Expansion阶段,传入当前需要开始搜索的节点(例如根节点),根据exploration/exploitation算法返回最好的需要expend的节点,注意如果节点是叶子结点直接返回。基本策略是先找当前未选择过的子节点,如果有多个则随机选。如果都选择过就找权衡过exploration/exploitation的UCB值最大的,如果UCB值相等则随机选。"""# Check if the current node is the leaf nodewhile node.get_state().is_terminal() == False:if node.is_all_expand():node = best_child(node, True)else:# Return the new sub nodesub_node = expand(node)return sub_node# Return the leaf nodereturn nodedef default_policy(node):"""蒙特卡罗树搜索的Simulation阶段,输入一个需要expand的节点,随机操作后创建新的节点,返回新增节点的reward。注意输入的节点应该不是子节点,而且是有未执行的Action可以expend的。基本策略是随机选择Action。"""# Get the state of the gamecurrent_state = node.get_state()# Run until the game overwhile current_state.is_terminal() == False:# Pick one random action to play and get next statecurrent_state = current_state.get_next_state_with_random_choice()final_state_reward = current_state.compute_reward()return final_state_rewarddef expand(node):"""输入一个节点,在该节点上拓展一个新的节点,使用random方法执行Action,返回新增的节点。注意,需要保证新增的节点与其他节点Action不同。"""tried_sub_node_states = [sub_node.get_state() for sub_node in node.get_children()]new_state = node.get_state().get_next_state_with_random_choice()# Check until get the new state which has the different action from otherswhile new_state in tried_sub_node_states:new_state = node.get_state().get_next_state_with_random_choice()sub_node = Node()sub_node.set_state(new_state)node.add_child(sub_node)return sub_nodedef best_child(node, is_exploration):"""使用UCB算法,权衡exploration和exploitation后选择得分最高的子节点,注意如果是预测阶段直接选择当前Q值得分最高的。"""# TODO: Use the min float valuebest_score = -sys.maxsizebest_sub_node = None# Travel all sub nodes to find the best onefor sub_node in node.get_children():# Ignore exploration for inferenceif is_exploration:C = 1 / math.sqrt(2.0)else:C = 0.0# UCB = quality / times + C * sqrt(2 * ln(total_times) / times)left = sub_node.get_quality_value() / sub_node.get_visit_times()right = 2.0 * math.log(node.get_visit_times()) / sub_node.get_visit_times()score = left + C * math.sqrt(right)if score > best_score:best_sub_node = sub_nodebest_score = scorereturn best_sub_nodedef backup(node, reward):"""蒙特卡洛树搜索的Backpropagation阶段,输入前面获取需要expend的节点和新执行Action的reward,反馈给expend节点和上游所有节点并更新对应数据。"""# Update util the root nodewhile node != None:# Update the visit timesnode.visit_times_add_one()# Update the quality valuenode.quality_value_add_n(reward)# Change the node to the parent nodenode = node.parentdef monte_carlo_tree_search(node):"""实现蒙特卡洛树搜索算法,传入一个根节点,在有限的时间内根据之前已经探索过的树结构expand新节点和更新数据,然后返回只要exploitation最高的子节点。蒙特卡洛树搜索包含四个步骤,Selection、Expansion、Simulation、Backpropagation。前两步使用tree policy找到值得探索的节点。第三步使用default policy也就是在选中的节点上随机算法选一个子节点并计算reward。最后一步使用backup也就是把reward更新到所有经过的选中节点的节点上。进行预测时,只需要根据Q值选择exploitation最大的节点即可,找到下一个最优的节点。"""computation_budget = 2# Run as much as possible under the computation budgetfor i in range(computation_budget):# 1. Find the best node to expandexpand_node = tree_policy(node)# 2. Random run to add node and get rewardreward = default_policy(expand_node)# 3. Update all passing nodes with rewardbackup(expand_node, reward)# N. Get the best next nodebest_next_node = best_child(node, False)return best_next_nodedef main():# Create the initialized state and initialized nodeinit_state = State()init_node = Node()init_node.set_state(init_state)current_node = init_node# Set the rounds to playfor i in range(10):print("Play round: {}".format(i + 1))current_node = monte_carlo_tree_search(current_node)print("Choose node: {}".format(current_node))if __name__ == "__main__":main()
蒙特卡洛树搜索(MCTS)的实例代码相关推荐
- 面向初学者的蒙特卡洛树搜索MCTS详解及其实现
目录 0. 序言 1. 蒙特卡洛算法的前身今世 2. 蒙特卡洛搜索算法的原理 2.1 Exploration and Exploitation(探索与利用) 2.2 Upper Confidence ...
- 蒙特卡洛搜索树python_python实现的基于蒙特卡洛树搜索(MCTS)与UCT RAVE的五子棋游戏...
更新 2017.2.23有更新,见文末. MCTS与UCT 下面的内容引用自徐心和与徐长明的论文<计算机博弈原理与方法学概述>: 蒙特卡洛模拟对局就是从某一棋局出发,随机走棋.有人形象地比 ...
- 强化学习(八):Dyna架构与蒙特卡洛树搜索MCTS
强化学习(八):Dyna架构与蒙特卡洛树搜索MCTS 在基于表格型强化学习方法中,比较常见的方法有动态规划法.蒙特卡洛法,时序差分法,多步引导法等.其中动态规划法是一种基于模型的方法(Model- ...
- 蒙特卡洛树搜索 MCTS
原文地址 http://mcts.ai/about/index.html 什么是 MCTS? 全称 Monte Carlo Tree Search,是一种人工智能问题中做出最优决策的方法,一般是在组合 ...
- python实现的基于蒙特卡洛树搜索(MCTS)与UCT RAVE的五子棋游戏
转自: http://www.cnblogs.com/xmwd/p/python_game_based_on_MCTS_and_UCT_RAVE.html 更新 2017.2.23有更新,见文末 ...
- 蒙特卡洛树搜索 MCTS 入门
引言 你如果是第一次听到蒙特卡洛,可能会认为这是一个人名.那么你就大错特错,蒙特卡洛不是一个人名,而是一个地方,还一个赌场名!!!但是这不是我们的重点. 我们今天的主题就是入门蒙特卡洛树搜索, ...
- 蒙特卡洛树搜索(MCTS)实现简易五子棋AI
蒙特卡洛树搜索算法可以通过自我对弈模拟得到不同状态分支中获胜的概率,从而获得最优的策略.代码部分可以分为Node类和State类.Node类通过关联父节点和子节点实现树结构,同时保存每个节点的属性:S ...
- 蒙特卡洛树搜索(MCTS)详解
蒙特卡洛树搜索(MCTS)详解 蒙特卡洛树搜索是一种经典的树搜索算法,名镇一时的 AlphaGo 的技术背景就是结合蒙特卡洛树搜索和深度策略价值网络,因此击败了当时的围棋世界冠军.它对于求解这种大规模 ...
- 强化学习笔记:AlphaGo(AlphaZero) ,蒙特卡洛树搜索(MCTS)
1 AlphaZero的状态 围棋的棋盘是 19 × 19 的网格,可以在两条线交叉的地方放置棋子,一共有 361 个可以放置棋子的位置,因此动作空间是 A = {1, · · , 361}.比如动 ...
最新文章
- Apache启动时报Could not reliably determine the server's fully qualified domain name
- xdebug与wincachegrind配置
- HDU - 2389 Rain on your Parade(Hopcroft-Krap算法求二分图最大匹配)
- Windows 系统版本判断
- git clone 时候出现Please make sure you have the correct access rights and the repository exists.
- Spring的XML解析原理,java接口流程图
- LeetCode 712. Minimum ASCII Delete Sum for Two Strings
- JS 操作 HTML 和 AJAX 请求后台数据
- 为什么快速排序比归并排序快
- unity3D游戏素材素材哪家强?Top3都在这!
- 拜访名寺古刹之圆通寺
- 初步观察UE蓝图的“Branch节点”,这个最简单的K2Node的代码
- [轻笔记] SHAP值的计算步骤
- 使用 JDB 调试 Android 应用程序
- 自媒体到底有多赚钱?首选赛道推荐
- 输入年,月,输出这一年的这个月有多少天
- 培养气质的98个好习惯
- Android SDK工具链清单
- UVA - 1471 Defense Lines 贪心+二分
- C语言编程输出象棋棋盘