Skip to content

Commit

Permalink
Bumb MCTS API
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasmarsh committed Feb 21, 2024
1 parent 60e42e6 commit a4eea82
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
23 changes: 15 additions & 8 deletions src/agent/mcts2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@ use std::sync::{Mutex, MutexGuard, OnceLock};
use crate::core::{game::Color, game::State, r#move::Move};

use mcts::game::{Game, PlayerIndex};
use mcts::strategies::mcts::{util, TreeSearch};
use mcts::strategies::mcts::select;
use mcts::strategies::mcts::{util, SearchConfig, TreeSearch};
use mcts::strategies::Search;

type NegoTS = TreeSearch<Nego, util::ScalarAmaf>;
type NegoTS = TreeSearch<Nego, util::Ucb1Tuned>;

static MCTS_CELL: OnceLock<Mutex<NegoTS>> = OnceLock::new();

fn get_agent() -> MutexGuard<'static, NegoTS> {
MCTS_CELL
.get_or_init(|| {
let mut mcts = NegoTS::default();
mcts.verbose = true;
let mcts = NegoTS::default()
.config(
SearchConfig::default()
.expand_threshold(2)
.max_iterations(usize::MAX)
.select(select::Ucb1Tuned {
exploration_constant: 1.625,
}),
)
.verbose(true);
Mutex::new(mcts)
})
.lock()
Expand All @@ -24,10 +33,7 @@ fn get_agent() -> MutexGuard<'static, NegoTS> {

pub fn step(state: &State, timeout: std::time::Duration) -> Option<Move> {
let mut mcts = get_agent();
mcts.strategy.max_time = timeout;
mcts.strategy.max_iterations = usize::MAX;
mcts.strategy.select.exploration_constant = 0.1;
mcts.strategy.playouts_before_expanding = 2;
mcts.config.max_time = timeout;
// mcts.strategy.max_iterations = 40000;
Some(mcts.choose_action(state))
}
Expand All @@ -38,6 +44,7 @@ impl PlayerIndex for Color {
}
}

#[derive(Clone)]
struct Nego;

impl Game for Nego {
Expand Down
2 changes: 1 addition & 1 deletion src/ui/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl UIState {
fn new() -> Self {
Self {
show_spinner: false,
agent_white: Agent::Mcts2(std::time::Duration::from_secs(10)),
agent_white: Agent::Mcts2(std::time::Duration::from_secs(40)),
// agent_white: Agent::Human,
agent_black: Agent::Mcts(std::time::Duration::from_secs(10)),
user: None,
Expand Down

0 comments on commit a4eea82

Please sign in to comment.