diff --git a/.gitignore b/.gitignore index 7612197..f074a5b 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,4 @@ docs/site/ # environment. Manifest.toml -# Setting files -.vscode/ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7a73a41 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,2 @@ +{ +} \ No newline at end of file diff --git a/Project.toml b/Project.toml index ddb72b0..8a6e5ca 100644 --- a/Project.toml +++ b/Project.toml @@ -7,3 +7,4 @@ version = "0.1.0" AlphaZero = "8ed9eb0b-7496-408d-8c8b-2119aeea02cd" CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/src/dummy_example/game.jl b/src/dummy_example/game.jl index 7a48fcc..866dead 100644 --- a/src/dummy_example/game.jl +++ b/src/dummy_example/game.jl @@ -4,11 +4,6 @@ using Crayons const RL = CommonRLInterface -# To avoid episodes of unbounded length, we put an arbitrary limit to the length of an -# episode. Because time is not captured in the state, this introduces a slight bias in -# the value function. -const EPISODE_LENGTH_BOUND = 15 - mutable struct World <: AbstractEnv state::Int # The sum of the previous steps time::Int # Count of the steps taken @@ -25,13 +20,13 @@ end RL.actions(env::World) = [0, 1, 2] RL.observe(env::World) = env.state -RL.terminated(env::World) = env.time > EPISODE_LENGTH_BOUND +RL.terminated(env::World) = env.time > 10 function RL.act!(env::World, a) env.state += a env.time += 1 if env.state == 10 - return 1 + return 5 / env.time elseif env.state > 10 return -1 end @@ -40,7 +35,7 @@ end @provide RL.player(env::World) = 1 # A one player game @provide RL.players(env::World) = [1] -@provide RL.observations(env::World) = (env.state, env.time) #[SA[x, y] for x in 1:env.size[1], y in 1:env.size[2]] +@provide RL.observations(env::World) = (env.state, env.time) @provide RL.clone(env::World) = World(env.state, env.state) @provide RL.state(env::World) = env.state @provide RL.setstate!(env::World, s) = (env.state = s) @@ -66,7 +61,7 @@ const action_names = ["Stay", "Add One", "Add Two"] function GI.action_string(env::World, a) idx = findfirst(==(a), RL.actions(env)) - return isnothing(idx) ? "?" : action_names[idx + 1] + return isnothing(idx) ? "?" : action_names[idx] end @@ -76,7 +71,7 @@ function GI.parse_action(env::World, s) end function GI.read_state(env::World) - return nothing + return (env.state, env.time) end GI.heuristic_value(::World) = 0. diff --git a/src/dummy_example/params.jl b/src/dummy_example/params.jl index 8cbcc21..f5448ce 100644 --- a/src/dummy_example/params.jl +++ b/src/dummy_example/params.jl @@ -1,8 +1,8 @@ Network = NetLib.SimpleNet netparams = NetLib.SimpleNetHP( - width=100, - depth_common=4, + width=25, + depth_common=3, use_batch_norm=false) self_play = SelfPlayParams( @@ -10,7 +10,7 @@ self_play = SelfPlayParams( num_games=1000, num_workers=4, batch_size=4, - use_gpu=false, + use_gpu=true, reset_every=16, flip_probability=0., alternate_colors=false), @@ -34,13 +34,13 @@ arena = ArenaParams( update_threshold=0.00) learning = LearningParams( - use_gpu=false, + use_gpu=true, use_position_averaging=false, samples_weighing_policy=CONSTANT_WEIGHT, - rewards_renormalization=10, + rewards_renormalization=1, l2_regularization=1e-4, optimiser=Adam(lr=5e-3), - batch_size=64, + batch_size=32, loss_computation_batch_size=2048, nonvalidity_penalty=1., min_checkpoints_per_epoch=1, @@ -51,7 +51,7 @@ params = Params( arena=arena, self_play=self_play, learning=learning, - num_iters=5, + num_iters=2, memory_analysis=nothing, # ternary_outcome=false, use_symmetries=false, diff --git a/src/dummy_example/run.jl b/src/dummy_example/run.jl index 6bd26e1..a8bb9e9 100644 --- a/src/dummy_example/run.jl +++ b/src/dummy_example/run.jl @@ -6,4 +6,5 @@ using AlphaZero: Scripts experiment = Experiment( "AddToTen", GameSpec(), params, Network, netparams, benchmark) -Scripts.dummy_run(experiment) \ No newline at end of file +Scripts.train(experiment) +Scripts.explore(experiment)