Skip to content

Commit

Permalink
change alg name to WU-UCT
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Nov 17, 2019
1 parent 067ccbe commit 5514f73
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
10 changes: 5 additions & 5 deletions Node/P_UCTnode.py → Node/WU_UCTnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from Utils.MovingAvegCalculator import MovingAvegCalculator


class P_UCTnode():
class WU_UCTnode():
def __init__(self, action_n, state, checkpoint_idx, parent, tree,
prior_prob = None, is_head = False):
self.action_n = action_n
Expand Down Expand Up @@ -65,7 +65,7 @@ def all_child_updated(self):

# Shallowly clone itself, contains necessary data only.
def shallow_clone(self):
node = P_UCTnode(
node = WU_UCTnode(
action_n = self.action_n,
state = deepcopy(self.state),
checkpoint_idx = self.checkpoint_idx,
Expand Down Expand Up @@ -159,7 +159,7 @@ def update_history(self, idx, action_taken, reward):
self.traverse_history[idx] = (action_taken, reward)
return True

# Incomplete update, called by P_UCT.py
# Incomplete update, called by WU_UCT.py
def update_incomplete(self, idx):
action_taken = self.traverse_history[idx][0]

Expand All @@ -169,7 +169,7 @@ def update_incomplete(self, idx):
self.children_visit_count[action_taken] += 1
self.visit_count += 1

# Complete update, called by P_UCT.py
# Complete update, called by WU_UCT.py
def update_complete(self, idx, accu_reward):
if idx not in self.traverse_history:
raise RuntimeError("idx {} should be in traverse_history".format(idx))
Expand All @@ -195,7 +195,7 @@ def add_child(self, action, child_state, checkpoint_idx, prior_prob = None):
if self.children[action] is not None:
node = self.children[action]
else:
node = P_UCTnode(
node = WU_UCTnode(
action_n = self.action_n,
state = child_state,
checkpoint_idx = checkpoint_idx,
Expand Down
6 changes: 3 additions & 3 deletions Tree/P_UCT.py → Tree/WU_UCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import logging

from Node.P_UCTnode import P_UCTnode
from Node.WU_UCTnode import WU_UCTnode

from Env.EnvWrapper import EnvWrapper

Expand All @@ -13,7 +13,7 @@
from Mem.CheckpointManager import CheckpointManager


class P_UCT():
class WU_UCT():
def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5,
gamma = 1.0, expansion_worker_num = 16, simulation_worker_num = 16,
policy = "Random", seed = 123, device = "cpu", record_video = False):
Expand Down Expand Up @@ -152,7 +152,7 @@ def simulate_single_move(self, state):

# Construct root node
self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx)
self.root_node = P_UCTnode(
self.root_node = WU_UCTnode(
action_n = self.action_n,
state = state,
checkpoint_idx = self.global_saving_idx,
Expand Down
11 changes: 6 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import scipy.io as sio
import os

from Tree.P_UCT import P_UCT
from Tree.WU_UCT import WU_UCT
from Tree.UCT import UCT

from Utils.NetworkDistillation.Distillation import train_distillation
Expand Down Expand Up @@ -60,10 +60,11 @@ def main():
if args.mode == "MCTS":
# Model initialization
if args.model == "WU-UCT":
MCTStree = P_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
args.MCTS_max_width, args.gamma, args.expansion_worker_num,
args.simulation_worker_num, policy = args.policy,
seed = args.seed, device = args.device, record_video = args.record_video)
MCTStree = WU_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
args.MCTS_max_width, args.gamma, args.expansion_worker_num,
args.simulation_worker_num, policy = args.policy,
seed = args.seed, device = args.device,
record_video = args.record_video)
elif args.model == "UCT":
MCTStree = UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
args.MCTS_max_width, args.gamma, policy = args.policy, seed = args.seed)
Expand Down

0 comments on commit 5514f73

Please sign in to comment.