ReinforcementLearning.jl, as the name suggests, is a package for reinforcement learning research in Julia.
Our design principles are:
- Reusability and extensibility: Provide carefully designed components and interfaces to help users implement new algorithms.
- Easy experimentation: Make it easy for new users to run benchmark experiments, compare different algorithms, evaluate and diagnose agents.
- Reproducibility: Facilitate reproducibility from traditional tabular methods to modern deep reinforcement learning algorithms.
ReinforcementLearning.jl itself is just a wrapper around several other packages inside the JuliaReinforcementLearning org. The relationships between the packages are shown below:
```
+-----------------------------------------------------------------------------------+
|                                                                                   |
|                            ReinforcementLearning.jl                               |
|                                                                                   |
|  +------------------------------+                                                 |
|  | ReinforcementLearningBase.jl |                                                 |
|  +--------|---------------------+                                                 |
|           |                                                                       |
|           |     +--------------------------------------+                          |
|           |     | ReinforcementLearningEnvironments.jl |                          |
|           |     |                                      |                          |
|           |     |     (Conditionally depends on)       |                          |
|           |     |                                      |                          |
|           |     |     ArcadeLearningEnvironment.jl     |                          |
|           +---->+     OpenSpiel.jl                     |                          |
|           |     |     POMDPs.jl                        |                          |
|           |     |     PyCall.jl                        |                          |
|           |     |     ViZDoom.jl                       |                          |
|           |     |     Maze.jl (WIP)                    |                          |
|           |     +--------------------------------------+                          |
|           |                                                                       |
|           |     +------------------------------+                                  |
|           +---->+ ReinforcementLearningCore.jl |                                  |
|                 +--------|---------------------+                                  |
|                          |                                                        |
|                          |     +-----------------------------+                    |
|                          +---->+ ReinforcementLearningZoo.jl |                    |
|                          |     +-----------------------------+                    |
|                          |                                                        |
|                          |     +----------------------------------------+         |
|                          +---->+ ReinforcementLearningAnIntroduction.jl |         |
|                                +----------------------------------------+         |
+-----------------------------------------------------------------------------------+
```
This package can be installed from the package manager in Julia's REPL:

```
] add ReinforcementLearning
```
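Equivalently, the package can be added with the Pkg API, which is convenient in scripts and CI setups:

```julia
# Install via the Pkg API instead of the REPL's pkg> mode.
using Pkg
Pkg.add("ReinforcementLearning")
```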
```julia
using ReinforcementLearning
using StatsBase: mean

# A classic control environment with a fixed seed for reproducibility.
env = CartPoleEnv(; T = Float32, seed = 123)

# An agent couples a policy (how to act) with a trajectory (what to remember).
agent = Agent(
    policy = RandomPolicy(env; seed = 456),
    trajectory = CircularCompactSARTSATrajectory(; capacity = 3, state_type = Float32, state_size = (4,)),
)

# Hooks collect statistics during the run.
hook = ComposedHook(TotalRewardPerEpisode(), TimePerStep())
run(agent, env, StopAfterEpisode(10_000), hook)
@info "stats for random policy" avg_reward = mean(hook[1].rewards) avg_fps = 1 / mean(hook[2].times)

# ┌ Info: stats for random policy
# │   avg_reward = 21.0591
# └   avg_fps = 1.6062450808744398e6
```
See also here for a detailed explanation.
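Conceptually, `run(agent, env, stop_condition, hook)` drives a standard agent–environment interaction loop: reset the environment, act until the episode terminates, and fire hooks along the way. The sketch below is a plain-Julia illustration of that loop with a toy environment and a random policy; every name in it (`ToyEnv`, `step!`, `run_loop`, etc.) is hypothetical, and it is not the library's actual implementation.

```julia
# Hand-rolled sketch of the loop that `run` performs. All types and
# functions here are illustrative, not the library's real API.

mutable struct ToyEnv      # hypothetical environment: terminates after 5 steps
    t::Int
end

reset!(env::ToyEnv) = (env.t = 0)
is_terminated(env::ToyEnv) = env.t >= 5

function step!(env::ToyEnv, action)
    env.t += 1
    return 1.0             # constant reward of 1.0 per step
end

function run_loop(env, n_episodes)
    episode_rewards = Float64[]
    for _ in 1:n_episodes
        reset!(env)
        total = 0.0
        while !is_terminated(env)
            action = rand((:left, :right))   # a random policy picks an action
            total += step!(env, action)      # per-step hooks would fire here
        end
        push!(episode_rewards, total)        # per-episode hooks would fire here
    end
    return episode_rewards
end

rewards = run_loop(ToyEnv(0), 10)
```

In the real package, the stop condition (like `StopAfterEpisode(10_000)`) replaces the fixed episode count, and hooks such as `TotalRewardPerEpisode` do the bookkeeping that `episode_rewards` does here.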