Skip to content

Commit

Permalink
Remove Terminal / static table
Browse files Browse the repository at this point in the history
This also removes the Terminal type from the core, as it is not needed for the same reason. That meant I had to adapt a bit of code that relied on `getfield(p, :terminal)` working.

Fixes #1209
  • Loading branch information
visr committed Jul 15, 2024
1 parent eadc057 commit 1ef15e1
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 89 deletions.
14 changes: 7 additions & 7 deletions core/src/allocation_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ function get_subnetwork_capacity(
if edge_metadata.edge node_ids_subnetwork
id_src, id_dst = edge_metadata.edge

node_src = getfield(p, graph[id_src].type)
node_dst = getfield(p, graph[id_dst].type)

capacity_edge = Inf

# Find flow constraints for this edge
if is_flow_constraining(node_src)
if is_flow_constraining(id_src.type)
node_src = getfield(p, graph[id_src].type)

capacity_node_src = node_src.max_flow_rate[id_src.idx]
capacity_edge = min(capacity_edge, capacity_node_src)
end
if is_flow_constraining(node_dst)
if is_flow_constraining(id_dst.type)
node_dst = getfield(p, graph[id_dst].type)
capacity_node_dst = node_dst.max_flow_rate[id_dst.idx]
capacity_edge = min(capacity_edge, capacity_node_dst)
end
Expand All @@ -66,8 +66,8 @@ function get_subnetwork_capacity(
# If allowed by the nodes from this edge,
# allow allocation flow in opposite direction of the edge
if !(
is_flow_direction_constraining(node_src) ||
is_flow_direction_constraining(node_dst)
is_flow_direction_constraining(id_src.type) ||
is_flow_direction_constraining(id_dst.type)
)
capacity[reverse(edge_metadata.edge)] = capacity_edge
end
Expand Down
2 changes: 2 additions & 0 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ for sv in nodeschemas
node, kind = nodetype(sv)
push!(nodekinds[node], kind)
end
# Terminal has no tables
nodekinds[:Terminal] = Symbol[]

"Convert a string from CamelCase to snake_case."
function snake_case(str::AbstractString)::String
Expand Down
8 changes: 0 additions & 8 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,13 +521,6 @@ is_pid_controlled: whether the flow rate of this outlet is governed by PID contr
end
end

"""
node_id: node ID of the Terminal node
"""
@kwdef struct Terminal <: AbstractParameterNode
node_id::Vector{NodeID}
end

"""
The data for a single compound variable
node_id:: The ID of the DiscreteControl that listens to this variable
Expand Down Expand Up @@ -722,7 +715,6 @@ const ModelGraph{T} = MetaGraph{
flow_boundary::FlowBoundary
pump::Pump{T}
outlet::Outlet{T}
terminal::Terminal
discrete_control::DiscreteControl
pid_control::PidControl{T}
user_demand::UserDemand
Expand Down
7 changes: 0 additions & 7 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,6 @@ function Outlet(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{In
)
end

function Terminal(db::DB, config::Config)::Terminal
node_id = get_ids(db, "Terminal")
return Terminal(NodeID.(NodeType.Terminal, node_id, eachindex(node_id)))
end

function Basin(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{Int})::Basin
node_id = get_ids(db, "Basin")
n = length(node_id)
Expand Down Expand Up @@ -1089,7 +1084,6 @@ function Parameters(db::DB, config::Config)::Parameters
flow_boundary = FlowBoundary(db, config, graph)
pump = Pump(db, config, graph, chunk_sizes)
outlet = Outlet(db, config, graph, chunk_sizes)
terminal = Terminal(db, config)
discrete_control = DiscreteControl(db, config, graph)
pid_control = PidControl(db, config, chunk_sizes)
user_demand = UserDemand(db, config, graph)
Expand All @@ -1111,7 +1105,6 @@ function Parameters(db::DB, config::Config)::Parameters
flow_boundary,
pump,
outlet,
terminal,
discrete_control,
pid_control,
user_demand,
Expand Down
5 changes: 0 additions & 5 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
@schema "ribasim.basin.concentration" BasinConcentration
@schema "ribasim.basin.concentrationexternal" BasinConcentrationExternal
@schema "ribasim.basin.concentrationstate" BasinConcentrationState
@schema "ribasim.terminal.static" TerminalStatic
@schema "ribasim.fractionalflow.static" FractionalFlowStatic
@schema "ribasim.flowboundary.static" FlowBoundaryStatic
@schema "ribasim.flowboundary.time" FlowBoundaryTime
Expand Down Expand Up @@ -220,10 +219,6 @@ end
flow_rate::Float64
end

@version TerminalStaticV1 begin
node_id::Int32
end

@version DiscreteControlVariableV1 begin
node_id::Int32
compound_variable_id::Int32
Expand Down
26 changes: 14 additions & 12 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,20 +374,21 @@ function low_storage_factor(
end

"""Whether the given node node is flow constraining by having a maximum flow rate."""
is_flow_constraining(node::AbstractParameterNode) = hasfield(typeof(node), :max_flow_rate)
function is_flow_constraining(type::NodeType.T)::Bool
type in (NodeType.LinearResistance, NodeType.Pump, NodeType.Outlet)
end

"""Whether the given node is flow direction constraining (only in direction of edges)."""
is_flow_direction_constraining(node::AbstractParameterNode) = (
node isa Union{
Pump,
Outlet,
TabulatedRatingCurve,
FractionalFlow,
Terminal,
UserDemand,
FlowBoundary,
}
)
function is_flow_direction_constraining(type::NodeType.T)::Bool
type in (
NodeType.Pump,
NodeType.Outlet,
NodeType.TabulatedRatingCurve,
NodeType.FractionalFlow,
NodeType.UserDemand,
NodeType.FlowBoundary,
)
end

function has_main_network(allocation::Allocation)::Bool
if !is_active(allocation)
Expand Down Expand Up @@ -739,6 +740,7 @@ function collect_control_mappings!(p)::Nothing
(; control_mappings) = p.discrete_control

for node_type in instances(NodeType.T)
node_type == NodeType.Terminal && continue
node = getfield(p, Symbol(snake_case(string(node_type))))
if hasfield(typeof(node), :control_mapping)
control_mappings[node_type] = node.control_mapping
Expand Down
24 changes: 18 additions & 6 deletions core/test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,21 +236,33 @@ end
end

@testitem "constraints_from_nodes" begin
using Ribasim: Model, snake_case, nodetypes, is_flow_constraining
using Ribasim:
Model,
snake_case,
nodetypes,
NodeType,
is_flow_constraining,
is_flow_direction_constraining

toml_path = normpath(@__DIR__, "../../generated_testmodels/basic/ribasim.toml")
@test ispath(toml_path)
model = Model(toml_path)
(; p) = model.integrator
constraining_types = (:Pump, :Outlet, :LinearResistance)
directed =
(:Pump, :Outlet, :TabulatedRatingCurve, :FractionalFlow, :UserDemand, :FlowBoundary)

for type in nodetypes
type == :Terminal && continue # has no parameter field
node = getfield(p, snake_case(type))
for symbol in nodetypes
type = NodeType(symbol)
if type in constraining_types
@test is_flow_constraining(node)
@test is_flow_constraining(type)
else
@test !is_flow_constraining(node)
@test !is_flow_constraining(type)
end
if type in directed
@test is_flow_direction_constraining(type)
else
@test !is_flow_direction_constraining(type)
end
end
end
Expand Down
7 changes: 1 addition & 6 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
PumpStaticSchema,
TabulatedRatingCurveStaticSchema,
TabulatedRatingCurveTimeSchema,
TerminalStaticSchema,
UserDemandStaticSchema,
UserDemandTimeSchema,
)
Expand Down Expand Up @@ -169,11 +168,7 @@ def __getitem__(self, index: int) -> NodeData:
)


class Terminal(MultiNodeModel):
static: TableModel[TerminalStaticSchema] = Field(
default_factory=TableModel[TerminalStaticSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)
class Terminal(MultiNodeModel): ...


class PidControl(MultiNodeModel):
Expand Down
2 changes: 0 additions & 2 deletions python/ribasim/ribasim/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
pid_control,
pump,
tabulated_rating_curve,
terminal,
user_demand,
)

Expand All @@ -28,6 +27,5 @@
"pid_control",
"pump",
"tabulated_rating_curve",
"terminal",
"user_demand",
]
13 changes: 0 additions & 13 deletions python/ribasim/ribasim/nodes/terminal.py

This file was deleted.

4 changes: 0 additions & 4 deletions python/ribasim/ribasim/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,6 @@ class TabulatedRatingCurveTimeSchema(_BaseSchema):
flow_rate: Series[float] = pa.Field(nullable=False)


class TerminalStaticSchema(_BaseSchema):
node_id: Series[Int32] = pa.Field(nullable=False, default=0)


class UserDemandStaticSchema(_BaseSchema):
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
active: Series[pa.BOOL] = pa.Field(nullable=True)
Expand Down
10 changes: 5 additions & 5 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pandas.testing import assert_frame_equal
from pydantic import ValidationError
from ribasim import Model, Node, Solver
from ribasim.nodes import basin, pump, terminal, user_demand
from ribasim.nodes import basin, pump, user_demand
from shapely.geometry import Point


Expand Down Expand Up @@ -84,13 +84,13 @@ def test_repr():


def test_extra_columns():
terminal_static = terminal.Static(meta_id=[-1, -2, -3])
assert "meta_id" in terminal_static.df.columns
assert (terminal_static.df.meta_id == [-1, -2, -3]).all()
pump_static = pump.Static(meta_id=[-1], flow_rate=[1.2])
assert "meta_id" in pump_static.df.columns
assert pump_static.df.meta_id.iloc[0] == -1

with pytest.raises(ValidationError):
# Extra column "extra" needs "meta_" prefix
terminal.Static(meta_id=[-1, -2, -3], extra=[-1, -2, -3])
pump.Static(extra=[-2], flow_rate=[1.2])


def test_extra_spatial_columns():
Expand Down
14 changes: 0 additions & 14 deletions ribasim_qgis/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,20 +638,6 @@ def attributes(cls) -> list[QgsField]:
]


class TerminalStatic(Input):
@classmethod
def input_type(cls) -> str:
return "Terminal / static"

@classmethod
def geometry_type(cls) -> str:
return "No Geometry"

@classmethod
def attributes(cls) -> list[QgsField]:
return [QgsField("node_id", QVariant.Int)]


class FlowBoundaryStatic(Input):
@classmethod
def input_type(cls) -> str:
Expand Down

0 comments on commit 1ef15e1

Please sign in to comment.