Skip to content

Commit

Permalink
Make node ID unique to the node type (#1113)
Browse files Browse the repository at this point in the history
Needed for #1110.

Before node IDs were globally unique. If you had the ID, you could look
up the type in the Node table. When designing #1110 we came to the
conclusion that this was not the right choice. It requires you to be
aware of the IDs that are used throughout the model, whereas with this
you only need to make sure that `Pump #5` doesn't already exist.

Most of the updates in this PR were to correct the tests. Some of the
code and error messages became easier to read. If we talk about a node
then we always know what the type is, there is no need to look it up
first.

Most tables stay the same, e.g. if you have a `Terminal / static` table
with a `node_id` Int column, you know that this refers to a Terminal
NodeID. Only when connecting to other nodes do we need to specify the
type next to the ID. So the `Edge` table now gets `from_node_type` next
to `from_node_id`, the `PidControl / static` gets `listen_node_type`
next to `listen_node_id`, etc. These extra columns are currently
automatically filled in by Ribasim-Python on model save, hence they
don't require changing the test models.

In terms of implementation, this basically adds the `type` field to
`NodeID` and fixes the resulting errors.

```julia
struct NodeID
    type::NodeType.T
    value::Int
end
```

It does not yet test if models with the same node IDs (`Pump #1` and
`Basin #1`) work, but this is hard to do right now with Ribasim Python,
so best left for a later moment.

---------

Co-authored-by: Hofer-Julian <[email protected]>
  • Loading branch information
visr and Hofer-Julian authored Feb 12, 2024
1 parent bbd57ca commit a39cc24
Show file tree
Hide file tree
Showing 20 changed files with 470 additions and 395 deletions.
22 changes: 10 additions & 12 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,27 @@ function get_value(
(; basin, flow_boundary, level_boundary) = p

if variable == "level"
hasindex_basin, basin_idx = id_index(basin.node_id, node_id)
level_boundary_idx = findsorted(level_boundary.node_id, node_id)

if hasindex_basin
if node_id.type == NodeType.Basin
_, basin_idx = id_index(basin.node_id, node_id)
_, level = get_area_and_level(basin, basin_idx, u[basin_idx])
elseif level_boundary_idx !== nothing
elseif node_id.type == NodeType.LevelBoundary
level_boundary_idx = findsorted(level_boundary.node_id, node_id)
level = level_boundary.level[level_boundary_idx](t + Δt)
else
error(
"Level condition node '$node_id' is neither a basin nor a level boundary.",
)
end

value = level

elseif variable == "flow_rate"
flow_boundary_idx = findsorted(flow_boundary.node_id, node_id)

if flow_boundary_idx === nothing
if node_id.type == NodeType.FlowBoundary
flow_boundary_idx = findsorted(flow_boundary.node_id, node_id)
value = flow_boundary.flow_rate[flow_boundary_idx](t + Δt)
else
error("Flow condition node $node_id is not a flow boundary.")
end

value = flow_boundary.flow_rate[flow_boundary_idx](t + Δt)
else
error("Unsupported condition variable $variable.")
end
Expand Down Expand Up @@ -418,7 +416,7 @@ function update_basin(integrator)::Nothing
)

for row in timeblock
hasindex, i = id_index(node_id, NodeID(row.node_id))
hasindex, i = id_index(node_id, NodeID(NodeType.Basin, row.node_id))
@assert hasindex "Table 'Basin / time' contains non-Basin IDs"
set_table_row!(table, row, i)
end
Expand Down Expand Up @@ -461,7 +459,7 @@ function update_tabulated_rating_curve!(integrator)::Nothing
id = first(group).node_id
level = [row.level for row in group]
flow_rate = [row.flow_rate for row in group]
i = searchsortedfirst(node_id, NodeID(id))
i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id))
tables[i] = LinearInterpolation(flow_rate, level; extrapolate = true)
end
return nothing
Expand Down
18 changes: 13 additions & 5 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
execute(db, "SELECT fid, type, allocation_network_id FROM Node ORDER BY fid")
edge_rows = execute(
db,
"SELECT fid, from_node_id, to_node_id, edge_type, allocation_network_id FROM Edge ORDER BY fid",
"SELECT fid, from_node_type, from_node_id, to_node_type, to_node_id, edge_type, allocation_network_id FROM Edge ORDER BY fid",
)
# Node IDs per subnetwork
node_ids = Dict{Int, Set{NodeID}}()
Expand All @@ -34,7 +34,7 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
graph_data = nothing,
)
for row in node_rows
node_id = NodeID(row.fid)
node_id = NodeID(row.type, row.fid)
# Process allocation network ID
if ismissing(row.allocation_network_id)
allocation_network_id = 0
Expand All @@ -51,15 +51,23 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow_vertical_dict[node_id] = flow_vertical_counter
end
end
for (; fid, from_node_id, to_node_id, edge_type, allocation_network_id) in edge_rows
for (;
fid,
from_node_type,
from_node_id,
to_node_type,
to_node_id,
edge_type,
allocation_network_id,
) in edge_rows
try
# hasfield does not work
edge_type = getfield(EdgeType, Symbol(edge_type))
catch
error("Invalid edge type $edge_type.")
end
id_src = NodeID(from_node_id)
id_dst = NodeID(to_node_id)
id_src = NodeID(from_node_type, from_node_id)
id_dst = NodeID(to_node_type, to_node_id)
if ismissing(allocation_network_id)
allocation_network_id = 0
end
Expand Down
31 changes: 25 additions & 6 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
# EdgeType.flow and NodeType.FlowBoundary
@enumx EdgeType flow control none
@eval @enumx NodeType $(config.nodetypes...)

# Support creating a NodeType enum instance from a symbol or string
function NodeType.T(s::Symbol)::NodeType.T
symbol_map = EnumX.symbol_map(NodeType.T)
for (sym, val) in symbol_map
sym == s && return NodeType.T(val)
end
throw(ArgumentError("Invalid value for NodeType: $s"))
end

NodeType.T(str::AbstractString) = NodeType.T(Symbol(str))

struct NodeID
type::NodeType.T
value::Int
end

NodeID(type::Symbol, value::Int) = NodeID(NodeType.T(type), value)
NodeID(type::AbstractString, value::Int) = NodeID(NodeType.T(type), value)

Base.Int(id::NodeID) = id.value
Base.convert(::Type{NodeID}, value::Int) = NodeID(value)
Base.convert(::Type{Int}, id::NodeID) = id.value
Base.broadcastable(id::NodeID) = Ref(id)
Base.show(io::IO, id::NodeID) = print(io, '#', Int(id))
Base.show(io::IO, id::NodeID) = print(io, id.type, " #", Int(id))

function Base.isless(id_1::NodeID, id_2::NodeID)::Bool
if id_1.type != id_2.type
error("Cannot compare NodeIDs of different types")
end
return Int(id_1) < Int(id_2)
end

Expand Down Expand Up @@ -64,8 +85,6 @@ struct Allocation
}
end

@enumx EdgeType flow control none

"""
Type for storing metadata of nodes in the graph
type: type of the node
Expand Down Expand Up @@ -318,7 +337,7 @@ struct Pump{T} <: AbstractParameterNode
control_mapping,
is_pid_controlled,
) where {T}
if valid_flow_rates(node_id, get_tmp(flow_rate, 0), control_mapping, :Pump)
if valid_flow_rates(node_id, get_tmp(flow_rate, 0), control_mapping)
return new{T}(
node_id,
active,
Expand Down Expand Up @@ -363,7 +382,7 @@ struct Outlet{T} <: AbstractParameterNode
control_mapping,
is_pid_controlled,
) where {T}
if valid_flow_rates(node_id, get_tmp(flow_rate, 0), control_mapping, :Outlet)
if valid_flow_rates(node_id, get_tmp(flow_rate, 0), control_mapping)
return new{T}(
node_id,
active,
Expand Down
Loading

0 comments on commit a39cc24

Please sign in to comment.