Skip to content

Commit

Permalink
Implement writing our own format
Browse files Browse the repository at this point in the history
  • Loading branch information
Zinoex committed Nov 15, 2023
1 parent ede66fc commit cf9cec4
Showing 1 changed file with 64 additions and 5 deletions.
69 changes: 64 additions & 5 deletions src/Data/imdp.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

function read_imdp_jl_file(path)
mdp_or_mc, terminal_states = Dataset(path) do dataset
n = Int32(dataset.attrib["num_states"] + 1)
initial_state = Int32(1) # dataset.attrib["initial_state"]
n = Int32(dataset.attrib["num_states"])
initial_state = dataset.attrib["initial_state"]
model = dataset.attrib["model"]

@assert model ["imdp", "imc"]
Expand Down Expand Up @@ -33,7 +33,7 @@ function read_imdp_jl_file(path)
)

prob = MatrixIntervalProbabilities(; lower = P̲, upper = P̅)
terminal_states = Int32[n] # convert.(Int32, dataset["terminal_states"][:])
terminal_states = convert.(Int32, dataset["terminal_states"][:])

if model == "imdp"
return read_imdp_jl_mdp(dataset, prob, initial_state), terminal_states
Expand All @@ -46,6 +46,7 @@ function read_imdp_jl_file(path)
end

function read_imdp_jl_mdp(dataset, prob, initial_state)
@assert dataset.attrib["model"] == "imdp"
@assert dataset.attrib["cols"] == "from/action"

stateptr = convert.(Int32, dataset["stateptr"][:])
Expand All @@ -56,12 +57,70 @@ function read_imdp_jl_mdp(dataset, prob, initial_state)
end

function read_imdp_jl_mc(dataset, prob, initial_state)
@assert dataset.attrib["model"] == "imc"
@assert dataset.attrib["cols"] == "from"

mc = IntervalMarkovChain(prob, Int32(initial_state))
return mc
end

function write_imdp_jl_file(path, mdp_or_mc)
# TODO: implement
function write_imdp_jl_file(path, mdp_or_mc, terminal_states)
Dataset(path, "c") do dataset
dataset.attrib["format"] = "sparse_csc"
dataset.attrib["num_states"] = num_states(mdp_or_mc)
dataset.attrib["rows"] = "to"
dataset.attrib["initial_state"] = initial_state(mdp_or_mc)

prob = transition_prob(mdp_or_mc)
l = lower(prob)
g = gap(prob)

defDim(dataset, "lower_colptr", length(l.colptr))
v = defVar(dataset, "lower_colptr", Int32, ("lower_colptr",))
v[:] = l.colptr

defDim(dataset, "lower_rowval", length(l.rowval))
v = defVar(dataset, "lower_rowval", Int32, ("lower_rowval",))
v[:] = l.rowval

defDim(dataset, "lower_nzval", length(l.nzval))
v = defVar(dataset, "lower_nzval", eltype(l.nzval), ("lower_nzval",))
v[:] = l.nzval

defDim(dataset, "upper_colptr", length(g.colptr))
v = defVar(dataset, "upper_colptr", Int32, ("upper_colptr",))
v[:] = g.colptr

defDim(dataset, "upper_rowval", length(g.rowval))
v = defVar(dataset, "upper_rowval", Int32, ("upper_rowval",))
v[:] = g.rowval

defDim(dataset, "upper_nzval", length(g.nzval))
v = defVar(dataset, "upper_nzval", eltype(g.nzval), ("upper_nzval",))
v[:] = l.nzval + g.nzval

defDim(dataset, "terminal_states", length(terminal_states))
v = defVar(dataset, "terminal_states", Int32, ("terminal_states",))
v[:] = terminal_states

write_imdp_jl_model_specific(dataset, mdp_or_mc)
end
end

function write_imdp_jl_model_specific(dataset, mdp::IntervalMarkovDecisionProcess)
dataset.attrib["model"] = "imdp"
dataset.attrib["cols"] = "from/action"

defDim(dataset, "stateptr", length(stateptr(mdp)))
v = defVar(dataset, "stateptr", Int32, ("stateptr",))
v[:] = stateptr(mdp)

defDim(dataset, "action_vals", length(actions(mdp)))
v = defVar(dataset, "action_vals", eltype(actions(mdp)), ("action_vals",))
v[:] = actions(mdp)
end

function write_imdp_jl_model_specific(dataset, mc::IntervalMarkovChain)
dataset.attrib["model"] = "imc"
dataset.attrib["cols"] = "from"
end

0 comments on commit cf9cec4

Please sign in to comment.