-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrrt_baselines.jl
42 lines (33 loc) · 1.2 KB
/
rrt_baselines.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
using ReinforcementLearning
using Flux
using Zygote
using Plots
using Random
using Logging
using StatsBase
using BSON
using WandbMacros
using RRTStar
using Metareasoning
# Baseline run: evaluate `RandomPolicy()` on an RRT* control environment whose growth
# factor is pinned to a single static value (growth_factor_range = (gf, gf)), with the
# interrupt action disabled and no focus bias. Writes per-episode scores, prints their
# mean, and records a short video of the same policy.
# NOTE(review): presumably used as a reference point for learned metareasoning
# policies (file is named rrt_baselines.jl) — confirm.
projectname = "RRTStarMetareasoning"
gf = 1.0 # Small
# gf = 2.0 # Large
experimentname = "StaticGrowthFactor=$gf"
logdir = joinpath("logs", projectname, experimentname)

# Run parameters (previously inline magic numbers).
max_samples = 1000       # passed to the env as `max_samples`
steps_per_episode = 20   # env monitors every max_samples ÷ steps_per_episode samples
n_eval_episodes = 1000   # episodes averaged for the reported score
n_clip_episodes = 10     # episodes' worth of steps recorded in the clip
eval_seed = 42
clip_fps = 10

env = RRTStarControlEnv(
    max_samples = max_samples,
    # Each episode = steps_per_episode steps.
    monitoring_interval = max_samples ÷ steps_per_episode,
    α = 1, β = 0,
    allow_interrupt_action = false,
    initial_growth_factor = gf,
    growth_factor_range = (gf, gf), # no wiggle room to adjust growth factor
    focus_level = 0.0, # no bias
)

# --------------------- Evaluation ------------------------------------------
# Create output directories (mkpath also creates `logdir`) before the long evaluation,
# so a bad path fails fast instead of after all episodes have run.
mkpath(joinpath(logdir, "clips"))

# Seed for testing. RandomPolicy() samples actions from the global RNG by default, so
# seeding only the environment would still leave the scores non-reproducible.
Random.seed!(eval_seed)
Random.seed!(env, eval_seed)

println("Evaluating")
Rs = [evaluate_policy_one_episode(RandomPolicy(), env) for _ in 1:n_eval_episodes]
open(joinpath(logdir, "final_scores.csv"), "w") do io
    join(io, Rs, "\n")  # one value per evaluated episode, no header
end
println("Mean utility: ", mean(Rs))

println("Making a clip")
clip = record_clip(env, policy = RandomPolicy(),
                   steps = steps_per_episode * n_clip_episodes)
mp4(clip, joinpath(logdir, "clips", "final_video.mp4"), fps = clip_fps)
gif(clip, joinpath(logdir, "clips", "final_video.gif"), fps = clip_fps)