-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathquick.jl
299 lines (253 loc) · 10 KB
/
quick.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# MDP model whose whole definition is stored in a single `NamedTuple`.
# Type parameters (as filled in by the keyword constructor in this file):
#   ID - identifier of the model (returned by `id(m)`)
#   S  - state type, A - action type (forwarded to `MDP{S,A}`)
#   D  - concrete type of the `NamedTuple` held in `data`
struct QuickMDP{ID,S,A,D<:NamedTuple} <: MDP{S,A}
# keyword arguments supplied at construction, after preprocessing and defaults
data::D
end
"""
    QuickMDP(gen::Function, [id]; kwargs...)

Build a generative MDP model from the function `gen` and any further keyword arguments.

`gen` receives three arguments (a state, an action, and a random number generator) and
should return a `NamedTuple` with key `sp` (the next state) and key `r` (the reward).

Keyword values may be static objects or functions. See the QuickPOMDPs.jl documentation
for more information.
"""
function QuickMDP(gen::Function, id=uuid4(); kwargs...)
    # Delegate to the keyword-only constructor, passing `gen` as one more keyword.
    return QuickMDP(id; gen=gen, kwargs...)
end
"""
    QuickMDP([id]; kwargs...)

Build an MDP model from keyword arguments. Keyword values may be static objects or
functions. See the QuickPOMDPs.jl documentation for more information.

Each keyword value is first passed through `preprocess`, then defaults are filled in by
`quick_defaults!`, and the state and action types are inferred from the keywords.
"""
function QuickMDP(id=uuid4(); kwargs...)
    kwd = Dict{Symbol, Any}(kwargs)

    # Run the preprocessing hook on every supplied keyword value.
    for key in collect(keys(kwd))
        kwd[key] = preprocess(Val(key), kwd[key])
    end

    quick_defaults!(kwd)

    S = infer_statetype(kwd)
    A = infer_actiontype(kwd)
    data = namedtuple(keys(kwd)...)(values(kwd)...)
    return QuickMDP{id, S, A, typeof(data)}(data)
end
# Return the identifier stored in the first type parameter of a `QuickMDP`.
function id(::QuickMDP{ID}) where {ID}
    return ID
end
# POMDP model whose whole definition is stored in a single `NamedTuple`.
# Type parameters (as filled in by the keyword constructor in this file):
#   ID - identifier of the model (returned by `id(m)`)
#   S  - state type, A - action type, O - observation type (forwarded to `POMDP{S,A,O}`)
#   D  - concrete type of the `NamedTuple` held in `data`
struct QuickPOMDP{ID,S,A,O,D<:NamedTuple} <: POMDP{S,A,O}
# keyword arguments supplied at construction, after preprocessing and defaults
data::D
end
"""
    QuickPOMDP(gen::Function, [id]; kwargs...)

Build a generative POMDP model from the function `gen` and any further keyword arguments.

`gen` receives three arguments (a state, an action, and a random number generator) and
should return a `NamedTuple` with key `sp` (the next state), key `o` (the observation),
and key `r` (the reward).

Keyword values may be static objects or functions. See the QuickPOMDPs.jl documentation
for more information.
"""
function QuickPOMDP(gen::Function, id=uuid4(); kwargs...)
    # Delegate to the keyword-only constructor, passing `gen` as one more keyword.
    return QuickPOMDP(id; gen=gen, kwargs...)
end
"""
    QuickPOMDP([id]; kwargs...)

Build a POMDP model from keyword arguments. Keyword values may be static objects or
functions. See the QuickPOMDPs.jl documentation for more information.

Each keyword value is first passed through `preprocess`, then defaults are filled in by
`quick_defaults!`, `quick_warnings` is run on the result, and the state, action, and
observation types are inferred from the keywords.
"""
function QuickPOMDP(id=uuid4(); kwargs...)
    kwd = Dict{Symbol, Any}(kwargs)

    # Run the preprocessing hook on every supplied keyword value.
    for key in collect(keys(kwd))
        kwd[key] = preprocess(Val(key), kwd[key])
    end

    quick_defaults!(kwd)
    quick_warnings(kwd)

    S = infer_statetype(kwd)
    A = infer_actiontype(kwd)
    O = infer_obstype(kwd)
    data = namedtuple(keys(kwd)...)(values(kwd)...)
    return QuickPOMDP{id, S, A, O, typeof(data)}(data)
end
# Return the identifier stored in the first type parameter of a `QuickPOMDP`.
function id(::QuickPOMDP{ID}) where {ID}
    return ID
end
# Either kind of quick model; used as the argument type for methods shared by
# `QuickMDP` and `QuickPOMDP`.
const QuickModel = Union{QuickMDP, QuickPOMDP}
"""
    preprocess(x)
    preprocess(argval::Val, x)

Hook that is called on each keyword argument before anything else is done with it.
It was designed to let other packages handle PyObjects.

The one-argument method returns `x` unchanged; the two-argument method ignores the
`Val` tag and forwards to the one-argument method.
"""
function preprocess(x)
    return x
end

function preprocess(argval::Val, x)
    return preprocess(x)
end
"""
    quick_defaults!(kwd::Dict)

Fill in default entries of the keyword dictionary `kwd` in place.

- `:discount` defaults to `1.0` and `:isterminal` defaults to `false`.
- When `:stateindex`, `:actionindex`, or `:obsindex` is absent but the corresponding
  `:states`, `:actions`, or `:observations` entry is present and has a finite `length`,
  an index `Dict` mapping each element to its position in the enumeration is added.
- No `:actionindex` is built when `:actions` is a `Dict` or a function without a
  zero-argument method.
"""
function quick_defaults!(kwd::Dict)
    get!(kwd, :discount, 1.0)
    get!(kwd, :isterminal, false)

    if !haskey(kwd, :stateindex) && haskey(kwd, :states)
        space = _call(Val(:states), kwd[:states], ())
        if hasmethod(length, typeof((space,))) && length(space) < Inf
            kwd[:stateindex] = Dict(s => i for (i, s) in enumerate(space))
        end
    end

    if !haskey(kwd, :actionindex) && haskey(kwd, :actions)
        ka = kwd[:actions]
        # a purely state-dependent action function (e.g. s->(1,2)) or a Dict cannot be
        # enumerated without a state, so no index is built for it
        enumerable = !(ka isa Dict) && (!(ka isa Function) || hasmethod(ka, Tuple{}))
        if enumerable
            acts = _call(Val(:actions), ka, ())
            if hasmethod(length, typeof((acts,))) && length(acts) < Inf
                kwd[:actionindex] = Dict(a => i for (i, a) in enumerate(acts))
            end
        end
    end

    if !haskey(kwd, :obsindex) && haskey(kwd, :observations)
        obs = _call(Val(:observations), kwd[:observations], ())
        if hasmethod(length, typeof((obs,))) && length(obs) < Inf
            kwd[:obsindex] = Dict(o => i for (i, o) in enumerate(obs))
        end
    end
end
"""
    quick_warnings(kwd)

Emit warnings about suspicious entries in the keyword dictionary `kwd`.

- If `:initialstate` is present, try drawing a sample from it with a fixed-seed
  `MersenneTwister`; a `MethodError` or `ArgumentError` is turned into a warning and
  any other exception is rethrown.
- If `:reward` is present but is not a `Function`, warn about it.
"""
function quick_warnings(kwd)
    if haskey(kwd, :initialstate)
        isd = _call(Val(:initialstate), kwd[:initialstate], ())
        try
            rand(MersenneTwister(0), isd)
        catch ex
            # only sampling failures of these two kinds are downgraded to a warning
            ex isa Union{MethodError, ArgumentError} || rethrow(ex)
            @warn("Unable to call rand(rng, $isd). Is the `initialstate` that you supplied a distribution?")
        end
    end

    if haskey(kwd, :reward) && !(kwd[:reward] isa Function)
        @warn("`reward` must be a function; got $(kwd[:reward])")
    end
end
"""
    infer_statetype(kwd)

Determine the state type from the keyword dictionary `kwd`.

Sources are tried in this order: the `:statetype` entry, the element type of the
`:states` entry, the type of a sample drawn from the `:initialstate` entry with a
fixed-seed `MersenneTwister`. If none is present the result is `Any`. A warning is
logged whenever the result is `Any`.
"""
function infer_statetype(kwd)
    st = if haskey(kwd, :statetype)
        _call(Val(:statetype), kwd[:statetype], (), NamedTuple())
    elseif haskey(kwd, :states)
        eltype(_call(Val(:states), kwd[:states], (), NamedTuple()))
    elseif haskey(kwd, :initialstate)
        dist = _call(Val(:initialstate), kwd[:initialstate], (), NamedTuple())
        typeof(rand(MersenneTwister(0), dist))
    else
        Any
    end

    st == Any && @warn("Unable to infer state type for a Quick(PO)MDP; using Any. This may have significant performance consequences. Use the statetype keyword argument to specify a concrete state type.")
    return st
end
"""
    infer_actiontype(kwd)

Determine the action type from the keyword dictionary `kwd`.

The `:actiontype` entry is used when present. Otherwise the `:actions` entry is
inspected: a function without a zero-argument method gives `Any`, a `Dict` gives its
value type, and anything else gives the element type of the resolved action
collection. If neither entry is present the result is `Any`. A warning is logged
whenever the result is `Any`.
"""
function infer_actiontype(kwd)
    at = if haskey(kwd, :actiontype)
        _call(Val(:actiontype), kwd[:actiontype], (), NamedTuple())
    elseif haskey(kwd, :actions)
        kwa = kwd[:actions]
        if kwa isa Function && !hasmethod(kwa, Tuple{})
            # state-dependent actions only: nothing to enumerate without a state
            Any
        elseif kwa isa Dict
            valtype(kwa)
        else
            eltype(_call(Val(:actions), kwa, (), NamedTuple()))
        end
    else
        Any
    end

    at == Any && @warn("Unable to infer action type for a Quick(PO)MDP; using Any. This may have significant performance consequences. Use the actiontype keyword argument to specify a concrete action type.")
    return at
end
"""
    infer_obstype(kwd)

Determine the observation type from the keyword dictionary `kwd`.

Sources are tried in this order: the `:obstype` entry, the element type of the
`:observations` entry, and, when both `:initialobs` and `:initialstate` are present,
the type of an observation sampled from `:initialobs` at a sampled initial state
(both draws use a fixed-seed `MersenneTwister`). If none applies the result is `Any`.
A warning is logged whenever the result is `Any`.
"""
function infer_obstype(kwd)
    ot = if haskey(kwd, :obstype)
        _call(Val(:obstype), kwd[:obstype], (), NamedTuple())
    elseif haskey(kwd, :observations)
        eltype(_call(Val(:observations), kwd[:observations], (), NamedTuple()))
    elseif haskey(kwd, :initialobs) && haskey(kwd, :initialstate)
        state_dist = _call(Val(:initialstate), kwd[:initialstate], (), NamedTuple())
        s0 = rand(MersenneTwister(0), state_dist)
        obs_dist = _call(Val(:initialobs), kwd[:initialobs], (s0,), NamedTuple())
        typeof(rand(MersenneTwister(0), obs_dist))
    else
        Any
    end

    ot == Any && @warn("Unable to infer observation type for a QuickPOMDP; using Any. This may have significant performance consequences. Use the obstype keyword argument to specify a concrete observation type.")
    return ot
end
# Look up the entry called `name` in the model's data and resolve it with `args` and
# `kwargs` through the other `_call` methods. Throws `MissingQuickArgument` when the
# model has no such entry.
function _call(namev::Val{name}, m::QuickModel, args, kwargs=NamedTuple()) where name
    entry = get(() -> throw(MissingQuickArgument(m, name)), m.data, name)
    return _call(namev, entry, args, kwargs)
end
# Resolve a stored keyword value:
#   - a function is called with the positional `args` and keyword `kwargs`,
#   - a `Dict` is indexed with `args`,
#   - any other object is returned as-is.
function _call(::Val, f::Function, args, kwargs=NamedTuple())
    return f(args...; kwargs...)
end

function _call(v::Val, object, args, kwargs=NamedTuple())
    return object
end

function _call(v::Val, d::Dict, args, kwargs=NamedTuple())
    return d[args...]
end
# Define a method of the module-qualified function `f` (e.g. `POMDPs.discount`) for
# `QuickModel` that resolves the same-named entry of the model data through `_call`.
macro forward_to_data(f)
    @assert f.head == :. "@forward_to_data must be used with a module-qualified function expression, e.g. @forward_to_data POMDPs.discount"
    # second argument of the `Module.name` expression: the quoted function name
    key = f.args[2]
    return quote
        $f(m::QuickModel, args...; kwargs...) = _call(Val($key), m, args, kwargs)
    end
end
# Call the user-supplied `transition` entry with `(s, a)`.
# Throws `MissingQuickArgument` (pointing to `gen` as an alternative) when the model
# has no `transition` entry.
function POMDPs.transition(m::QuickModel, s, a)
    haskey(m.data, :transition) ||
        throw(MissingQuickArgument(m, :transition, types=[Function], also=[:gen]))
    return m.data.transition(s, a)
end
# Call the user-supplied `observation` entry.
#
# The entry is called with all of `args` when it has a matching method. Otherwise, when
# three arguments were given and the entry has a method for the last two, it is called
# with `args[2:3]`. If neither matches, it is still called with all of `args` so that
# the resulting error refers to the full call.
#
# Throws `MissingQuickArgument` (pointing to `gen` as an alternative) when the model
# has no `observation` entry.
function POMDPs.observation(m::QuickPOMDP, args...)
    haskey(m.data, :observation) ||
        throw(MissingQuickArgument(m, :observation, types=[Function], also=[:gen]))

    obs = m.data[:observation]
    if static_hasmethod(obs, typeof(args))
        return obs(args...)
    elseif length(args) == 3 && static_hasmethod(obs, typeof(args[2:3]))
        return obs(args[2:3]...)
    else
        return obs(args...)
    end
end
# Call the user-supplied `reward` entry, adapting the argument list to the signature
# the user actually defined.
#
# The entry is called with all of `args` when it has a matching method. Otherwise
# shorter prefixes are tried:
#   - POMDP with 4 arguments: the first three, then the first two;
#   - 3 arguments: the first two.
# If nothing matches, the entry is called with all of `args`, so the caller gets a
# `MethodError` for the full call rather than a silent `nothing`.
#
# Throws `MissingQuickArgument` when the model has no `reward` entry.
function POMDPs.reward(m::QuickModel, args...)
    haskey(m.data, :reward) || throw(MissingQuickArgument(m, :reward))

    r = m.data[:reward]
    if static_hasmethod(r, typeof(args)) # static_hasmethod could cause issues, but I think it is worth doing in this single spot
        return r(args...)
    elseif m isa POMDP && length(args) == 4
        if static_hasmethod(r, typeof(args[1:3])) # (s, a, sp, o) -> (s, a, sp)
            return r(args[1:3]...)
        elseif static_hasmethod(r, typeof(args[1:2])) # (s, a, sp, o) -> (s, a)
            return r(args[1:2]...)
        end
    elseif length(args) == 3 && static_hasmethod(r, typeof(args[1:2])) # (s, a, sp) -> (s, a)
        return r(args[1:2]...)
    end

    # No shorter signature matched (including the 4-argument POMDP case, which used to
    # fall through and return `nothing`): make the full call.
    return r(args...)
end
# Define `QuickModel` methods of these POMDPs functions that resolve the same-named
# entry of the model data via `_call`.
@forward_to_data POMDPs.initialstate
@forward_to_data POMDPs.initialobs
# Call the user-supplied `gen` entry with `(s, a, rng)`; when the model has no `gen`
# entry, return an empty `NamedTuple`.
function POMDPs.gen(m::QuickModel, s, a, rng)
    haskey(m.data, :gen) || return NamedTuple()
    return m.data.gen(s, a, rng)
end
# Define `QuickModel` methods of these POMDPs functions that resolve the same-named
# entry of the model data via `_call`.
@forward_to_data POMDPs.states
@forward_to_data POMDPs.actions
@forward_to_data POMDPs.observations
@forward_to_data POMDPs.discount
@forward_to_data POMDPs.stateindex
@forward_to_data POMDPs.actionindex
@forward_to_data POMDPs.obsindex
@forward_to_data POMDPs.isterminal
# Weight of an observation (the last element of `args`).
# Uses the user-supplied `obs_weight` entry when present; otherwise evaluates `pdf` of
# the observation distribution obtained from the remaining arguments. Throws
# `MissingQuickArgument` when the model has neither `obs_weight` nor `observation`.
function POMDPTools.obs_weight(m::QuickPOMDP, args...)
    haskey(m.data, :obs_weight) && return _call(Val(:obs_weight), m, args)

    if !haskey(m.data, :observation)
        throw(MissingQuickArgument(m, :obs_weight, types=[Function], also=[:observation]))
    end

    dist = observation(m, args[1:end-1]...)
    return pdf(dist, args[end])
end
# Define a `QuickModel` method of `POMDPTools.render` that resolves the `render` entry
# of the model data via `_call`.
@forward_to_data POMDPTools.render
# Choose the state-action reward wrapper for a quick model: `FunctionSAR` when the
# `reward` entry has a method taking (state, action), `LazyCachedSAR` otherwise.
function POMDPTools.StateActionReward(m::Union{QuickPOMDP,QuickMDP})
    rew = m.data[:reward]
    sa_signature = Tuple{statetype(m), actiontype(m)}
    return hasmethod(rew, sa_signature) ? FunctionSAR(m) : LazyCachedSAR(m)
end