From 5514f7315803266ff80506d9b0dd6404e0f05f5b Mon Sep 17 00:00:00 2001 From: liuanji <904051563@qq.com> Date: Sun, 17 Nov 2019 12:01:29 -0800 Subject: [PATCH] change alg name to WU-UCT --- Node/{P_UCTnode.py => WU_UCTnode.py} | 10 +++++----- Tree/{P_UCT.py => WU_UCT.py} | 6 +++--- main.py | 11 ++++++----- 3 files changed, 14 insertions(+), 13 deletions(-) rename Node/{P_UCTnode.py => WU_UCTnode.py} (97%) rename Tree/{P_UCT.py => WU_UCT.py} (99%) diff --git a/Node/P_UCTnode.py b/Node/WU_UCTnode.py similarity index 97% rename from Node/P_UCTnode.py rename to Node/WU_UCTnode.py index de3b5f8..ed33068 100755 --- a/Node/P_UCTnode.py +++ b/Node/WU_UCTnode.py @@ -5,7 +5,7 @@ from Utils.MovingAvegCalculator import MovingAvegCalculator -class P_UCTnode(): +class WU_UCTnode(): def __init__(self, action_n, state, checkpoint_idx, parent, tree, prior_prob = None, is_head = False): self.action_n = action_n @@ -65,7 +65,7 @@ def all_child_updated(self): # Shallowly clone itself, contains necessary data only. def shallow_clone(self): - node = P_UCTnode( + node = WU_UCTnode( action_n = self.action_n, state = deepcopy(self.state), checkpoint_idx = self.checkpoint_idx, @@ -159,7 +159,7 @@ def update_history(self, idx, action_taken, reward): self.traverse_history[idx] = (action_taken, reward) return True - # Incomplete update, called by P_UCT.py + # Incomplete update, called by WU_UCT.py def update_incomplete(self, idx): action_taken = self.traverse_history[idx][0] @@ -169,7 +169,7 @@ def update_incomplete(self, idx): self.children_visit_count[action_taken] += 1 self.visit_count += 1 - # Complete update, called by P_UCT.py + # Complete update, called by WU_UCT.py def update_complete(self, idx, accu_reward): if idx not in self.traverse_history: raise RuntimeError("idx {} should be in traverse_history".format(idx)) @@ -195,7 +195,7 @@ def add_child(self, action, child_state, checkpoint_idx, prior_prob = None): if self.children[action] is not None: node = self.children[action] else: - node = P_UCTnode( + node = WU_UCTnode( action_n = self.action_n, state = child_state, checkpoint_idx = checkpoint_idx, diff --git a/Tree/P_UCT.py b/Tree/WU_UCT.py similarity index 99% rename from Tree/P_UCT.py rename to Tree/WU_UCT.py index 815c3e3..8764806 100755 --- a/Tree/P_UCT.py +++ b/Tree/WU_UCT.py @@ -4,7 +4,7 @@ import time import logging -from Node.P_UCTnode import P_UCTnode +from Node.WU_UCTnode import WU_UCTnode from Env.EnvWrapper import EnvWrapper @@ -13,7 +13,7 @@ from Mem.CheckpointManager import CheckpointManager -class P_UCT(): +class WU_UCT(): def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5, gamma = 1.0, expansion_worker_num = 16, simulation_worker_num = 16, policy = "Random", seed = 123, device = "cpu", record_video = False): @@ -152,7 +152,7 @@ def simulate_single_move(self, state): # Construct root node self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx) - self.root_node = P_UCTnode( + self.root_node = WU_UCTnode( action_n = self.action_n, state = state, checkpoint_idx = self.global_saving_idx, diff --git a/main.py b/main.py index 61e44e6..6397fd6 100755 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ import scipy.io as sio import os -from Tree.P_UCT import P_UCT +from Tree.WU_UCT import WU_UCT from Tree.UCT import UCT from Utils.NetworkDistillation.Distillation import train_distillation @@ -60,10 +60,11 @@ def main(): if args.mode == "MCTS": # Model initialization if args.model == "WU-UCT": - MCTStree = P_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth, - args.MCTS_max_width, args.gamma, args.expansion_worker_num, - args.simulation_worker_num, policy = args.policy, - seed = args.seed, device = args.device, record_video = args.record_video) + MCTStree = WU_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth, + args.MCTS_max_width, args.gamma, args.expansion_worker_num, + args.simulation_worker_num, policy = args.policy, + seed = args.seed, device = args.device, + record_video = args.record_video) elif args.model == "UCT": MCTStree = UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth, args.MCTS_max_width, args.gamma, policy = args.policy, seed = args.seed)