Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
Api change (#19)
Browse files Browse the repository at this point in the history
* unify interfaces

* update screen after interact

* add minor comment

* add version check

* update docker

* revert to juali v1.2 due to https://github.com/JuliaLang/julia/pull/32408\#issuecomment-522168938

* update README
  • Loading branch information
findmyway authored Aug 28, 2019
1 parent 82a21f3 commit cd77d25
Show file tree
Hide file tree
Showing 14 changed files with 146 additions and 93 deletions.
4 changes: 1 addition & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ language: julia
os:
- linux
julia:
- 1.0
- 1.1
- nightly
- 1.2
notifications:
email: false

Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM julia:1.1
FROM julia:1.2

# install dependencies
RUN set -eux; \
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This package serves as a one-stop place for different kinds of reinforcement lea
Install:

```julia
(v1.1) pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl
pkg> add ReinforcementLearningEnvironments
```

## API
Expand Down Expand Up @@ -64,11 +64,11 @@ Take the `AtariEnv` for example:

1. Install this package by:
```julia
(v1.1) pkg> add ReinforcementLearningEnvironments
pkg> add ReinforcementLearningEnvironments
```
2. Install corresponding dependent package by:
```julia
(v1.1) pkg> add ArcadeLearningEnvironment
pkg> add ArcadeLearningEnvironment
```
3. Using the above two packages:
```julia
Expand Down
3 changes: 3 additions & 0 deletions src/ReinforcementLearningEnvironments.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module ReinforcementLearningEnvironments

export RLEnvs
const RLEnvs = ReinforcementLearningEnvironments

using Reexport, Requires

include("abstractenv.jl")
Expand Down
26 changes: 24 additions & 2 deletions src/abstractenv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render
export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render, Observation, get_reward, get_terminal, get_state, get_legal_actions

abstract type AbstractEnv end

Expand All @@ -7,4 +7,26 @@ function reset! end
function interact! end
function action_space end
function observation_space end
function render end
function render end

struct Observation{R, T, S, M<:NamedTuple}
reward::R
terminal::T
state::S
meta::M
end

Observation(;reward, terminal, state, kw...) = Observation(reward, terminal, state, merge(NamedTuple(), kw))

get_reward(obs::Observation) = obs.reward
get_terminal(obs::Observation) = obs.terminal
get_state(obs::Observation) = obs.state
get_legal_actions(obs::Observation) = obs.meta.legal_actions

# !!! >= julia v1.3
if VERSION >= v"1.3.0-rc1.0"
(env::AbstractEnv)(a) = interact!(env, a)
end

action_space(env::AbstractEnv) = env.action_space
observation_space(env::AbstractEnv) = env.observation_space
24 changes: 12 additions & 12 deletions src/environments/atari.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@ using ArcadeLearningEnvironment, GR

export AtariEnv

struct AtariEnv{To,F} <: AbstractEnv
mutable struct AtariEnv{To,F} <: AbstractEnv
ale::Ptr{Nothing}
screen::Array{UInt8, 1}
getscreen!::F
actions::Array{Int32, 1}
actions::Array{Int64, 1}
action_space::DiscreteSpace{Int}
observation_space::To
noopmax::Int
reward::Float32
end

action_space(env::AtariEnv) = env.action_space
observation_space(env::AtariEnv) = env.observation_space

"""
AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
color_averaging = true, repeat_action_probability = 0.)
Expand Down Expand Up @@ -51,24 +49,26 @@ function AtariEnv(name;
end
actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
action_space = DiscreteSpace(length(actions))
AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax)
AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax, 0.0f0)
end

function interact!(env::AtariEnv, a)
r = act(env.ale, env.actions[a])
env.reward = act(env.ale, env.actions[a])
env.getscreen!(env.ale, env.screen)
(observation=env.screen, reward=r, isdone=game_over(env.ale))
nothing
end

function observe(env::AtariEnv)
env.getscreen!(env.ale, env.screen)
(observation=env.screen, isdone=game_over(env.ale))
end
observe(env::AtariEnv) = Observation(
reward = env.reward,
terminal = game_over(env.ale),
state = env.screen
)

function reset!(env::AtariEnv)
reset_game(env.ale)
for _ in 1:rand(0:env.noopmax) act(env.ale, Int32(0)) end
env.getscreen!(env.ale, env.screen)
env.reward = 0.0f0 # dummy
nothing
end

Expand Down
11 changes: 6 additions & 5 deletions src/environments/classic_control/cart_pole.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ function CartPoleEnv(; T = Float64, gravity = T(9.8), masscart = T(1.),
cp
end

action_space(env::CartPoleEnv) = env.action_space
observation_space(env::CartPoleEnv) = env.observation_space

function reset!(env::CartPoleEnv{T}) where T <: Number
env.state[:] = T(.1) * rand(env.rng, T, 4) .- T(.05)
env.t = 0
Expand All @@ -53,7 +50,11 @@ function reset!(env::CartPoleEnv{T}) where T <: Number
nothing
end

observe(env::CartPoleEnv) = (observation=env.state, isdone=env.done)
observe(env::CartPoleEnv) = Observation(
reward = env.done ? 0.0 : 1.0,
terminal = env.done,
state = env.state
)

function interact!(env::CartPoleEnv{T}, a) where T <: Number
env.action = a
Expand All @@ -76,7 +77,7 @@ function interact!(env::CartPoleEnv{T}, a) where T <: Number
env.done = abs(env.state[1]) > env.params.xthreshold ||
abs(env.state[3]) > env.params.thetathreshold ||
env.t >= env.params.max_steps
(observation=env.state, reward=1., isdone=env.done)
nothing
end

function plotendofepisode(x, y, d)
Expand Down
79 changes: 42 additions & 37 deletions src/environments/classic_control/mdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,52 @@ export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochas
##### POMDPEnv
#####

mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG}
mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG} <: AbstractEnv
model::T
state::Ts
actions::Ta
action_space::DiscreteSpace
observation_space::DiscreteSpace
observation::Int
reward::Float64
rng::R
end

POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
model,
initialstate(model, rng),
actions(model),
DiscreteSpace(n_actions(model)),
DiscreteSpace(n_states(model)),
rng)
function POMDPEnv(model; rng=Random.GLOBAL_RNG)
state = initialstate(model, rng)
as = DiscreteSpace(n_actions(model))
os = DiscreteSpace(n_states(model))
actions_of_model = actions(model)
s, o, r = generate_sor(model, state, actions_of_model[rand(as)], rng)
obs = observationindex(model, o)
POMDPEnv(model, state, actions_of_model, as, os, obs, 0., rng)
end

function interact!(env::POMDPEnv, action)
s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
env.state = s
(observation = observationindex(env.model, o),
reward = r,
isdone = isterminal(env.model, s))
env.reward = r
env.observation = observationindex(env.model, o)
nothing
end

function observe(env::POMDPEnv)
(observation = observationindex(env.model, generate_o(env.model, env.state, env.rng)),
isdone = isterminal(env.model, env.state))
end
observe(env::POMDPEnv) = Observation(
reward = env.reward,
terminal = isterminal(env.model, env.state),
state = env.observation
)

#####
##### MDPEnv
#####

mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG}
mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG} <: AbstractEnv
model::T
state::Ts
actions::Ta
action_space::DiscreteSpace
observation_space::DiscreteSpace
reward::Float64
rng::R
end

Expand All @@ -58,10 +64,9 @@ MDPEnv(model; rng=Random.GLOBAL_RNG) = MDPEnv(
actions(model),
DiscreteSpace(n_actions(model)),
DiscreteSpace(n_states(model)),
rng)

action_space(env::Union{MDPEnv, POMDPEnv}) = env.action_space
observation_space(env::Union{MDPEnv, POMDPEnv}) = env.observation_space
0.,
rng
)

observationindex(env, o) = Int(o) + 1

Expand All @@ -74,15 +79,15 @@ function interact!(env::MDPEnv, action)
s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
r = POMDPs.reward(env.model, env.state, env.actions[action])
env.state = s
(observation = stateindex(env.model, s),
reward = r,
isdone = isterminal(env.model, s))
env.reward = r
nothing
end

function observe(env::MDPEnv)
(observation = stateindex(env.model, env.state),
isdone = isterminal(env.model, env.state))
end
observe(env::MDPEnv) = Observation(
reward = env.reward,
terminal = isterminal(env.model, env.state),
state = stateindex(env.model, env.state)
)

#####
##### SimpleMDPEnv
Expand All @@ -107,14 +112,15 @@ probabilities) `reward` of type `R` (see [`DeterministicStateActionReward`](@ref
[`NormalStateActionReward`](@ref)), array of initial states
`initialstates`, and `ns` - array of 0/1 indicating if a state is terminal.
"""
mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG} <: AbstractEnv
observation_space::DiscreteSpace
action_space::DiscreteSpace
state::Int
trans_probs::Array{T, 2}
reward::R
initialstates::Array{Int, 1}
isterminal::Array{Int, 1}
score::Float64
rng::S
end

Expand All @@ -125,12 +131,9 @@ function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
reward = DeterministicStateActionReward(reward)
end
SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
reward, initialstates, isterminal, rng)
reward, initialstates, isterminal, 0., rng)
end

observation_space(env::SimpleMDPEnv) = env.observation_space
action_space(env::SimpleMDPEnv) = env.action_space

# reward types
"""
struct DeterministicNextStateReward
Expand Down Expand Up @@ -208,13 +211,15 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int, 1}) = run!(mdp, policy[mdp.state])
function interact!(env::SimpleMDPEnv, action)
oldstate = env.state
run!(env, action)
r = reward(env.rng, env.reward, oldstate, action, env.state)
(observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
env.score = reward(env.rng, env.reward, oldstate, action, env.state)
nothing
end

function observe(env::SimpleMDPEnv)
(observation = env.state, isdone = env.isterminal[env.state] == 1)
end
observe(env::SimpleMDPEnv) = Observation(
reward = env.score,
terminal = env.isterminal[env.state] == 1,
state = env.state
)

function reset!(env::SimpleMDPEnv)
env.state = rand(env.rng, env.initialstates)
Expand Down
11 changes: 7 additions & 4 deletions src/environments/classic_control/mountain_car.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ function MountainCarEnv(; T = Float64, continuous = false,
reset!(env)
env
end

ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwargs...)

action_space(env::MountainCarEnv) = env.action_space
observation_space(env::MountainCarEnv) = env.observation_space
observe(env::MountainCarEnv) = (observation=env.state, isdone=env.done)
observe(env::MountainCarEnv) = Observation(
reward = env.done ? 0. : -1.,
terminal = env.done,
state = env.state
)

function reset!(env::MountainCarEnv{A, T}) where {A, T}
env.state[1] = .2 * rand(env.rng, T) - .6
Expand All @@ -78,7 +81,7 @@ function _interact!(env::MountainCarEnv, force)
env.t >= env.params.max_steps
env.state[1] = x
env.state[2] = v
(observation=env.state, reward=-1., isdone=env.done)
nothing
end

# adapted from https://github.com/JuliaML/Reinforce.jl/blob/master/src/envs/mountain_car.jl
Expand Down
Loading

0 comments on commit cd77d25

Please sign in to comment.