From 61c563d4ea52c8b7fea4f8d156c0689f56d23a63 Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Tue, 13 Feb 2024 08:26:04 +0100 Subject: [PATCH] Implement change requests from ports review (#7) Signed-off-by: Sylvain Leclerc --- src/andromede/expression/context_adder.py | 4 +- src/andromede/expression/copy.py | 12 +- src/andromede/expression/degree.py | 12 +- src/andromede/expression/evaluate.py | 12 +- src/andromede/expression/expression.py | 25 +++++ src/andromede/expression/indexing.py | 28 +++-- src/andromede/expression/port_resolver.py | 10 +- src/andromede/expression/print.py | 12 +- src/andromede/expression/visitor.py | 14 +-- src/andromede/model/model.py | 106 +++++++++++++++++- src/andromede/simulation/linearize.py | 38 ++++--- src/andromede/simulation/optimization.py | 15 +-- src/andromede/study/__init__.py | 1 - src/andromede/study/network.py | 31 ----- .../expressions/test_port_resolver.py | 1 - tests/andromede/test_andromede.py | 22 +--- tests/andromede/test_model.py | 19 ++++ 17 files changed, 224 insertions(+), 138 deletions(-) diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index c04b7227..812e95f7 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -38,12 +38,12 @@ def variable(self, node: VariableNode) -> ExpressionNode: def parameter(self, node: ParameterNode) -> ExpressionNode: return ComponentParameterNode(self.component_id, node.name) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: + def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: raise ValueError( "This expression has already been associated to another component." ) - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: + def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: raise ValueError( "This expression has already been associated to another component." ) diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index 8388c42a..c135ee59 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -43,12 +43,6 @@ class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]): Simply copies the whole AST. """ - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: - return ComponentParameterNode(node.component_id, node.name) - - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: - return ComponentVariableNode(node.component_id, node.name) - def literal(self, node: LiteralNode) -> ExpressionNode: return LiteralNode(node.value) @@ -63,6 +57,12 @@ def variable(self, node: VariableNode) -> ExpressionNode: def parameter(self, node: ParameterNode) -> ExpressionNode: return ParameterNode(node.name) + def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: + return ComponentVariableNode(node.component_id, node.name) + + def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: + return ComponentParameterNode(node.component_id, node.name) + def copy_expression_range( self, expression_range: ExpressionRange ) -> ExpressionRange: diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py index 69037545..cfd175cd 100644 --- a/src/andromede/expression/degree.py +++ b/src/andromede/expression/degree.py @@ -41,12 +41,6 @@ class ExpressionDegreeVisitor(ExpressionVisitor[int]): Computes degree of expression with respect to variables. """ - def comp_parameter(self, node: ComponentParameterNode) -> int: - return 0 - - def comp_variable(self, node: ComponentVariableNode) -> int: - return 1 - def literal(self, node: LiteralNode) -> int: return 0 @@ -78,6 +72,12 @@ def variable(self, node: VariableNode) -> int: def parameter(self, node: ParameterNode) -> int: return 0 + def comp_variable(self, node: ComponentVariableNode) -> int: + return 1 + + def comp_parameter(self, node: ComponentParameterNode) -> int: + return 0 + def time_operator(self, node: TimeOperatorNode) -> int: if node.name in ["TimeShift", "TimeEvaluation"]: return visit(node.operand, self) diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index 5981f648..b51c0e86 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -102,12 +102,6 @@ class EvaluationVisitor(ExpressionVisitorOperations[float]): context: ValueProvider - def comp_parameter(self, node: ComponentParameterNode) -> float: - return self.context.get_component_parameter_value(node.component_id, node.name) - - def comp_variable(self, node: ComponentVariableNode) -> float: - return self.context.get_component_variable_value(node.component_id, node.name) - def literal(self, node: LiteralNode) -> float: return node.value @@ -120,6 +114,12 @@ def variable(self, node: VariableNode) -> float: def parameter(self, node: ParameterNode) -> float: return self.context.get_parameter_value(node.name) + def comp_parameter(self, node: ComponentParameterNode) -> float: + return self.context.get_component_parameter_value(node.component_id, node.name) + + def comp_variable(self, node: ComponentVariableNode) -> float: + return self.context.get_component_variable_value(node.component_id, node.name) + def time_operator(self, node: TimeOperatorNode) -> float: raise NotImplementedError() diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index 81290807..ffbd4e22 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -179,6 +179,14 @@ def param(name: str) -> ParameterNode: @dataclass(frozen=True, eq=False) class ComponentParameterNode(ExpressionNode): + """ + Represents one parameter of one component. + + When building actual equations for a system, + we need to associated each parameter to its + actual component, at some point. + """ + component_id: str name: str @@ -189,6 +197,14 @@ def comp_param(component_id: str, name: str) -> ComponentParameterNode: @dataclass(frozen=True, eq=False) class ComponentVariableNode(ExpressionNode): + """ + Represents one variable of one component. + + When building actual equations for a system, + we need to associated each variable to its + actual component, at some point. + """ + component_id: str name: str @@ -321,6 +337,15 @@ def expression_range( class InstancesTimeIndex: + """ + Defines a set of time indices on which a time operator operates. + + In particular, it defines time indices created by the shift operator. + + The actual indices can either be defined as a time range defined by + 2 expression, or as a list of expressions. + """ + expressions: Union[List[ExpressionNode], ExpressionRange] def __init__( diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 37547706..11051dd5 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -68,16 +68,6 @@ class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]): context: IndexingStructureProvider - def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: - return self.context.get_component_parameter_structure( - node.component_id, node.name - ) - - def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - return self.context.get_component_variable_structure( - node.component_id, node.name - ) - def literal(self, node: LiteralNode) -> IndexingStructure: return IndexingStructure(False, False) @@ -109,6 +99,16 @@ def parameter(self, node: ParameterNode) -> IndexingStructure: scenario = self.context.get_parameter_structure(node.name).scenario == True return IndexingStructure(time, scenario) + def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: + return self.context.get_component_variable_structure( + node.component_id, node.name + ) + + def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: + return self.context.get_component_parameter_structure( + node.component_id, node.name + ) + def time_operator(self, node: TimeOperatorNode) -> IndexingStructure: time_operator_cls = getattr(andromede.expression.time_operator, node.name) if time_operator_cls.rolling(): @@ -126,10 +126,14 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure: return IndexingStructure(visit(node.operand, self).time, False) def port_field(self, node: PortFieldNode) -> IndexingStructure: - raise ValueError("Should be instantiated before computing indexing structure.") + raise ValueError( + "Port fields must be resolved before computing indexing structure." + ) def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure: - raise ValueError("Should be instantiated before computing indexing structure.") + raise ValueError( + "Port fields aggregators must be resolved before computing indexing structure." + ) def compute_indexation( diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py index e7f9dc08..6f333408 100644 --- a/src/andromede/expression/port_resolver.py +++ b/src/andromede/expression/port_resolver.py @@ -42,21 +42,21 @@ class PortResolver(CopyVisitor): their corresponding expression. """ - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] component_id: str + ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] def port_field(self, node: PortFieldNode) -> ExpressionNode: - expression = self.ports_expressions[ + expressions = self.ports_expressions[ PortFieldKey( self.component_id, PortFieldId(node.port_name, node.field_name) ) ] - if len(expression) != 1: + if len(expressions) != 1: raise ValueError( f"Invalid number of expression for port : {node.port_name}" ) else: - return expression[0] + return expressions[0] def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: if node.aggregator != "PortSum": @@ -80,4 +80,4 @@ def resolve_port( component_id: str, ports_expressions: Dict[PortFieldKey, List[ExpressionNode]], ) -> ExpressionNode: - return visit(expression, PortResolver(ports_expressions, component_id)) + return visit(expression, PortResolver(component_id, ports_expressions)) diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index f181c17b..c01ae76f 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -54,12 +54,6 @@ class PrinterVisitor(ExpressionVisitor[str]): TODO: remove parenthis where not necessary. """ - def comp_parameter(self, node: ComponentParameterNode) -> str: - return f"{node.component_id}.{node.name}" - - def comp_variable(self, node: ComponentVariableNode) -> str: - return f"{node.component_id}.{node.name}" - def literal(self, node: LiteralNode) -> str: return str(node.value) @@ -98,6 +92,12 @@ def variable(self, node: VariableNode) -> str: def parameter(self, node: ParameterNode) -> str: return node.name + def comp_variable(self, node: ComponentVariableNode) -> str: + return f"{node.component_id}.{node.name}" + + def comp_parameter(self, node: ComponentParameterNode) -> str: + return f"{node.component_id}.{node.name}" + # TODO: Add pretty print for node.instances_index def time_operator(self, node: TimeOperatorNode) -> str: return f"({visit(node.operand, self)}.{str(node.name)}({node.instances_index}))" diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index a89b3d88..25bbfb02 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -86,31 +86,31 @@ def parameter(self, node: ParameterNode) -> T: ... @abstractmethod - def time_operator(self, node: TimeOperatorNode) -> T: + def comp_parameter(self, node: ComponentParameterNode) -> T: ... @abstractmethod - def time_aggregator(self, node: TimeAggregatorNode) -> T: + def comp_variable(self, node: ComponentVariableNode) -> T: ... @abstractmethod - def scenario_operator(self, node: ScenarioOperatorNode) -> T: + def time_operator(self, node: TimeOperatorNode) -> T: ... @abstractmethod - def port_field(self, node: PortFieldNode) -> T: + def time_aggregator(self, node: TimeAggregatorNode) -> T: ... @abstractmethod - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: + def scenario_operator(self, node: ScenarioOperatorNode) -> T: ... @abstractmethod - def comp_parameter(self, node: ComponentParameterNode) -> T: + def port_field(self, node: PortFieldNode) -> T: ... @abstractmethod - def comp_variable(self, node: ComponentVariableNode) -> T: + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index e102144a..ba54ac31 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -19,14 +19,33 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Optional -from andromede.expression import ExpressionNode +from andromede.expression import ( + AdditionNode, + ComparisonNode, + DivisionNode, + ExpressionNode, + ExpressionVisitor, + LiteralNode, + MultiplicationNode, + NegationNode, + ParameterNode, + SubstractionNode, + VariableNode, +) from andromede.expression.degree import is_linear from andromede.expression.expression import ( + BinaryOperatorNode, ComponentParameterNode, ComponentVariableNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorNode, + TimeAggregatorNode, + TimeOperatorNode, ) from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.visitor import T, visit from andromede.model.constraint import Constraint from andromede.model.parameter import Parameter from andromede.model.port import PortType @@ -45,12 +64,16 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def get_component_parameter_structure( self, component_id: str, name: str ) -> IndexingStructure: - raise NotImplementedError("Cannot have instantiated parameters in models.") + raise NotImplementedError( + "Cannot have parameters associated to components in models." + ) def get_component_variable_structure( self, component_id: str, name: str ) -> IndexingStructure: - raise NotImplementedError("Cannot have instantiated parameters in models.") + raise NotImplementedError( + "Cannot have variables associated to components in models." + ) return Provider() @@ -83,6 +106,15 @@ class PortFieldDefinition: port_field: PortFieldId definition: ExpressionNode + def __post_init__(self) -> None: + _validate_port_field_expression(self) + + +def port_field_def( + port_name: str, field_name: str, definition: ExpressionNode +) -> PortFieldDefinition: + return PortFieldDefinition(PortFieldId(port_name, field_name), definition) + @dataclass(frozen=True) class Model: @@ -128,7 +160,6 @@ def __post_init__(self) -> None: raise ValueError( f"Invalid port field in port field definition: {port_field}" ) - # TODO: should we check something on the expression ? (comparison...) def get_all_constraints(self) -> Iterable[Constraint]: """ @@ -178,3 +209,70 @@ def model( if port_fields_definitions else {}, ) + + +class _PortFieldExpressionChecker(ExpressionVisitor[None]): + """ + Visits the whole expression to check there is no: + comparison, other port field, component-associated parametrs or variables... + """ + + def literal(self, node: LiteralNode) -> None: + pass + + def negation(self, node: NegationNode) -> None: + visit(node.operand, self) + + def _visit_binary_op(self, node: BinaryOperatorNode) -> None: + visit(node.left, self) + visit(node.right, self) + + def addition(self, node: AdditionNode) -> None: + self._visit_binary_op(node) + + def substraction(self, node: SubstractionNode) -> None: + self._visit_binary_op(node) + + def multiplication(self, node: MultiplicationNode) -> None: + self._visit_binary_op(node) + + def division(self, node: DivisionNode) -> None: + self._visit_binary_op(node) + + def comparison(self, node: ComparisonNode) -> None: + raise ValueError("Port definition cannot contain a comparison operator.") + + def variable(self, node: VariableNode) -> None: + pass + + def parameter(self, node: ParameterNode) -> None: + pass + + def comp_parameter(self, node: ComponentParameterNode) -> None: + raise ValueError( + "Port definition must not contain a parameter associated to a component." + ) + + def comp_variable(self, node: ComponentVariableNode) -> None: + raise ValueError( + "Port definition must not contain a variable associated to a component." + ) + + def time_operator(self, node: TimeOperatorNode) -> None: + visit(node.operand, self) + + def time_aggregator(self, node: TimeAggregatorNode) -> None: + visit(node.operand, self) + + def scenario_operator(self, node: ScenarioOperatorNode) -> None: + visit(node.operand, self) + + def port_field(self, node: PortFieldNode) -> None: + raise ValueError("Port definition cannot reference another port field.") + + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: + raise ValueError("Port definition cannot contain port field aggregation.") + + +def _validate_port_field_expression(definition: PortFieldDefinition) -> None: + visit(definition.definition, _PortFieldExpressionChecker()) diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py index b2ccd91f..9fe6738a 100644 --- a/src/andromede/simulation/linearize.py +++ b/src/andromede/simulation/linearize.py @@ -62,6 +62,24 @@ def variable(self, node: VariableNode) -> LinearExpression: def parameter(self, node: ParameterNode) -> LinearExpression: raise ValueError("Parameters must be evaluated before linearization.") + def comp_variable(self, node: ComponentVariableNode) -> LinearExpression: + return LinearExpression( + [ + Term( + 1, + node.component_id, + node.name, + self.structure_provider.get_component_variable_structure( + node.component_id, node.name + ), + ) + ], + 0, + ) + + def comp_parameter(self, node: ComponentParameterNode) -> LinearExpression: + raise ValueError("Parameters must be evaluated before linearization.") + def time_operator(self, node: TimeOperatorNode) -> LinearExpression: if self.value_provider is None: raise ValueError( @@ -123,24 +141,8 @@ def port_field(self, node: PortFieldNode) -> LinearExpression: raise ValueError("Port fields must be replaced before linearization.") def port_field_aggregator(self, node: PortFieldAggregatorNode) -> LinearExpression: - raise ValueError("Port fields must be replaced before linearization.") - - def comp_parameter(self, node: ComponentParameterNode) -> LinearExpression: - raise ValueError("Parameters must be evaluated before linearization.") - - def comp_variable(self, node: ComponentVariableNode) -> LinearExpression: - return LinearExpression( - [ - Term( - 1, - node.component_id, - node.name, - self.structure_provider.get_component_variable_structure( - node.component_id, node.name - ), - ) - ], - 0, + raise ValueError( + "Port fields aggregators must be replaced before linearization." ) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 89dc4909..60d9916b 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -58,17 +58,6 @@ class TimestepComponentVariableKey: scenario: Optional[int] = None -@dataclass(eq=True, frozen=True) -class TimestepFlowVariableKey: - """ - Identifies the solver variable for one timestep and one link. - """ - - link_id: str - block_timestep: Optional[int] = None - scenario: Optional[int] = None - - def _get_parameter_value( context: "OptimizationContext", block_timestep: int, @@ -141,9 +130,7 @@ def get_parameter_value(self, name: str) -> float: ) def parameter_is_constant_over_time(self, name: str) -> bool: - return _parameter_is_constant_over_time( - component, name, context, block_timestep, scenario - ) + return not component.model.parameters[name].structure.time return Provider() diff --git a/src/andromede/study/__init__.py b/src/andromede/study/__init__.py index 4daaa94e..2a89893a 100644 --- a/src/andromede/study/__init__.py +++ b/src/andromede/study/__init__.py @@ -21,7 +21,6 @@ TimeSeriesData, ) from .network import ( - Arc, Component, Network, Node, diff --git a/src/andromede/study/network.py b/src/andromede/study/network.py index 85ae905f..44b77974 100644 --- a/src/andromede/study/network.py +++ b/src/andromede/study/network.py @@ -46,26 +46,11 @@ class Node(Component): pass -@dataclass(frozen=True) -class Arc: - """ - An arc between 2 nodes of the network. - TODO: we could imagine that it would be a component. - """ - - id: str - node1_id: str - node2_id: str - - @dataclass(frozen=True) class PortRef: component: Component port_id: str - def get_id(self) -> str: - return f"{self.port_id}_{self.component.id}" - @dataclass() class PortsConnection: @@ -79,9 +64,6 @@ def __init__(self, port1: PortRef, port2: PortRef): self.master_port = {} self.__validate_ports() - def get_id(self) -> str: - return f"{self.port1.get_id()}__{self.port2.get_id()}" - def __validate_ports(self) -> None: model1 = self.port1.component.model model2 = self.port2.component.model @@ -134,7 +116,6 @@ class Network: def __init__(self, id: str): self.id: str = id self._nodes: Dict[str, Node] = {} - self._arcs: Dict[str, Arc] = {} self._components: Dict[str, Component] = {} self._connections: List[PortsConnection] = [] @@ -174,18 +155,6 @@ def all_components(self) -> Iterable[Component]: """ return itertools.chain(self.nodes, self.components) - def add_arc(self, arc: Arc) -> None: - self._check_node_exists(arc.node1_id) - self._check_node_exists(arc.node2_id) - self._arcs[arc.id] = arc - - def get_arc(self, arc_id: str) -> Arc: - return self._arcs[arc_id] - - @property - def arcs(self) -> Iterable[Arc]: - return self._arcs.values() - def connect(self, port1: PortRef, port2: PortRef) -> None: ports_connection = PortsConnection(port1, port2) self._connections.append(ports_connection) diff --git a/tests/andromede/expressions/test_port_resolver.py b/tests/andromede/expressions/test_port_resolver.py index 38a9a216..e38f20cb 100644 --- a/tests/andromede/expressions/test_port_resolver.py +++ b/tests/andromede/expressions/test_port_resolver.py @@ -42,7 +42,6 @@ def test_port_field_resolution_sum(): ports_expressions[key] = [var("flow1"), var("flow2")] expression_2 = port_field("port", "field").sum_connections() - # TODO remove 0 from sum() assert expressions_equal( resolve_port(expression_2, "com_id", ports_expressions), var("flow1") + var("flow2"), diff --git a/tests/andromede/test_andromede.py b/tests/andromede/test_andromede.py index ef5038ef..5aeec26a 100644 --- a/tests/andromede/test_andromede.py +++ b/tests/andromede/test_andromede.py @@ -13,7 +13,6 @@ import pytest from andromede.expression import literal, param, var -from andromede.expression.expression import port_field from andromede.expression.indexing_structure import IndexingStructure from andromede.libs.standard import ( BALANCE_PORT_TYPE, @@ -27,14 +26,7 @@ THERMAL_CLUSTER_MODEL_HD, UNSUPPLIED_ENERGY_MODEL, ) -from andromede.model import ( - Constraint, - Model, - ModelPort, - float_parameter, - float_variable, - model, -) +from andromede.model import Model, ModelPort, float_parameter, float_variable, model from andromede.model.model import PortFieldDefinition, PortFieldId from andromede.simulation import ( BlockBorderManagement, @@ -43,7 +35,6 @@ build_problem, ) from andromede.study import ( - Arc, ConstantData, DataBase, Network, @@ -59,13 +50,12 @@ def test_network() -> None: network = Network("test") assert network.id == "test" assert list(network.nodes) == [] - assert list(network.arcs) == [] assert list(network.components) == [] + assert list(network.all_components) == [] + assert list(network.connections) == [] with pytest.raises(KeyError): network.get_node("N") - with pytest.raises(KeyError): - network.get_arc("L") N1 = Node(model=NODE_BALANCE_MODEL, id="N1") N2 = Node(model=NODE_BALANCE_MODEL, id="N2") @@ -73,12 +63,6 @@ def test_network() -> None: network.add_node(N2) assert list(network.nodes) == [N1, N2] assert network.get_node(N1.id) == N1 - - with pytest.raises(ValueError): - network.add_arc(Arc("L", "N", "N2")) - network.add_arc(Arc("L", "N1", "N2")) - assert list(network.arcs) == [Arc("L", "N1", "N2")] - assert network.get_arc("L") == Arc("L", "N1", "N2") assert network.get_component("N1") == Node(model=NODE_BALANCE_MODEL, id="N1") with pytest.raises(KeyError): network.get_component("unknown") diff --git a/tests/andromede/test_model.py b/tests/andromede/test_model.py index 8bf08b67..85238fbc 100644 --- a/tests/andromede/test_model.py +++ b/tests/andromede/test_model.py @@ -15,11 +15,15 @@ from andromede.expression.expression import ( ExpressionNode, ExpressionRange, + comp_param, + comp_var, literal, param, + port_field, var, ) from andromede.model import Constraint, float_variable, model +from andromede.model.model import PortFieldDefinition, port_field_def @pytest.mark.parametrize( @@ -187,3 +191,18 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv objective_contribution=var("generation").variance(), ) assert str(exc.value) == "Objective contribution must be a linear expression." + + +@pytest.mark.parametrize( + "expression", + [ + var("x") <= 0, + comp_var("c", "x"), + comp_param("c", "x"), + port_field("p", "f"), + port_field("p", "f").sum_connections(), + ], +) +def test_invalid_port_field_definition_should_raise(expression: ExpressionNode) -> None: + with pytest.raises(ValueError) as exc: + port_field_def(port_name="p", field_name="f", definition=expression)