Skip to content

Commit

Permalink
Implement change requests from ports review (#7)
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl authored Feb 13, 2024
1 parent 0dbceb0 commit 61c563d
Show file tree
Hide file tree
Showing 17 changed files with 224 additions and 138 deletions.
4 changes: 2 additions & 2 deletions src/andromede/expression/context_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def variable(self, node: VariableNode) -> ExpressionNode:
def parameter(self, node: ParameterNode) -> ExpressionNode:
return ComponentParameterNode(self.component_id, node.name)

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
raise ValueError(
"This expression has already been associated to another component."
)
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]):
Simply copies the whole AST.
"""

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
return ComponentVariableNode(node.component_id, node.name)

def literal(self, node: LiteralNode) -> ExpressionNode:
return LiteralNode(node.value)

Expand All @@ -63,6 +57,12 @@ def variable(self, node: VariableNode) -> ExpressionNode:
def parameter(self, node: ParameterNode) -> ExpressionNode:
return ParameterNode(node.name)

def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode:
return ComponentVariableNode(node.component_id, node.name)

def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode:
return ComponentParameterNode(node.component_id, node.name)

def copy_expression_range(
self, expression_range: ExpressionRange
) -> ExpressionRange:
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/degree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ class ExpressionDegreeVisitor(ExpressionVisitor[int]):
Computes degree of expression with respect to variables.
"""

def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
return 1

def literal(self, node: LiteralNode) -> int:
return 0

Expand Down Expand Up @@ -78,6 +72,12 @@ def variable(self, node: VariableNode) -> int:
def parameter(self, node: ParameterNode) -> int:
return 0

def comp_variable(self, node: ComponentVariableNode) -> int:
return 1

def comp_parameter(self, node: ComponentParameterNode) -> int:
return 0

def time_operator(self, node: TimeOperatorNode) -> int:
if node.name in ["TimeShift", "TimeEvaluation"]:
return visit(node.operand, self)
Expand Down
12 changes: 6 additions & 6 deletions src/andromede/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ class EvaluationVisitor(ExpressionVisitorOperations[float]):

context: ValueProvider

def comp_parameter(self, node: ComponentParameterNode) -> float:
return self.context.get_component_parameter_value(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def literal(self, node: LiteralNode) -> float:
return node.value

Expand All @@ -120,6 +114,12 @@ def variable(self, node: VariableNode) -> float:
def parameter(self, node: ParameterNode) -> float:
return self.context.get_parameter_value(node.name)

def comp_parameter(self, node: ComponentParameterNode) -> float:
return self.context.get_component_parameter_value(node.component_id, node.name)

def comp_variable(self, node: ComponentVariableNode) -> float:
return self.context.get_component_variable_value(node.component_id, node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
raise NotImplementedError()

Expand Down
25 changes: 25 additions & 0 deletions src/andromede/expression/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ def param(name: str) -> ParameterNode:

@dataclass(frozen=True, eq=False)
class ComponentParameterNode(ExpressionNode):
"""
Represents one parameter of one component.
When building actual equations for a system,
we need to associated each parameter to its
actual component, at some point.
"""

component_id: str
name: str

Expand All @@ -189,6 +197,14 @@ def comp_param(component_id: str, name: str) -> ComponentParameterNode:

@dataclass(frozen=True, eq=False)
class ComponentVariableNode(ExpressionNode):
"""
Represents one variable of one component.
When building actual equations for a system,
we need to associated each variable to its
actual component, at some point.
"""

component_id: str
name: str

Expand Down Expand Up @@ -321,6 +337,15 @@ def expression_range(


class InstancesTimeIndex:
"""
Defines a set of time indices on which a time operator operates.
In particular, it defines time indices created by the shift operator.
The actual indices can either be defined as a time range defined by
2 expression, or as a list of expressions.
"""

expressions: Union[List[ExpressionNode], ExpressionRange]

def __init__(
Expand Down
28 changes: 16 additions & 12 deletions src/andromede/expression/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,6 @@ class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]):

context: IndexingStructureProvider

def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
node.component_id, node.name
)

def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
)

def literal(self, node: LiteralNode) -> IndexingStructure:
return IndexingStructure(False, False)

Expand Down Expand Up @@ -109,6 +99,16 @@ def parameter(self, node: ParameterNode) -> IndexingStructure:
scenario = self.context.get_parameter_structure(node.name).scenario == True
return IndexingStructure(time, scenario)

def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure:
return self.context.get_component_variable_structure(
node.component_id, node.name
)

def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure:
return self.context.get_component_parameter_structure(
node.component_id, node.name
)

def time_operator(self, node: TimeOperatorNode) -> IndexingStructure:
time_operator_cls = getattr(andromede.expression.time_operator, node.name)
if time_operator_cls.rolling():
Expand All @@ -126,10 +126,14 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure:
return IndexingStructure(visit(node.operand, self).time, False)

def port_field(self, node: PortFieldNode) -> IndexingStructure:
raise ValueError("Should be instantiated before computing indexing structure.")
raise ValueError(
"Port fields must be resolved before computing indexing structure."
)

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure:
raise ValueError("Should be instantiated before computing indexing structure.")
raise ValueError(
"Port fields aggregators must be resolved before computing indexing structure."
)


def compute_indexation(
Expand Down
10 changes: 5 additions & 5 deletions src/andromede/expression/port_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ class PortResolver(CopyVisitor):
their corresponding expression.
"""

ports_expressions: Dict[PortFieldKey, List[ExpressionNode]]
component_id: str
ports_expressions: Dict[PortFieldKey, List[ExpressionNode]]

def port_field(self, node: PortFieldNode) -> ExpressionNode:
expression = self.ports_expressions[
expressions = self.ports_expressions[
PortFieldKey(
self.component_id, PortFieldId(node.port_name, node.field_name)
)
]
if len(expression) != 1:
if len(expressions) != 1:
raise ValueError(
f"Invalid number of expression for port : {node.port_name}"
)
else:
return expression[0]
return expressions[0]

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode:
if node.aggregator != "PortSum":
Expand All @@ -80,4 +80,4 @@ def resolve_port(
component_id: str,
ports_expressions: Dict[PortFieldKey, List[ExpressionNode]],
) -> ExpressionNode:
return visit(expression, PortResolver(ports_expressions, component_id))
return visit(expression, PortResolver(component_id, ports_expressions))
12 changes: 6 additions & 6 deletions src/andromede/expression/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ class PrinterVisitor(ExpressionVisitor[str]):
TODO: remove parenthis where not necessary.
"""

def comp_parameter(self, node: ComponentParameterNode) -> str:
return f"{node.component_id}.{node.name}"

def comp_variable(self, node: ComponentVariableNode) -> str:
return f"{node.component_id}.{node.name}"

def literal(self, node: LiteralNode) -> str:
return str(node.value)

Expand Down Expand Up @@ -98,6 +92,12 @@ def variable(self, node: VariableNode) -> str:
def parameter(self, node: ParameterNode) -> str:
return node.name

def comp_variable(self, node: ComponentVariableNode) -> str:
return f"{node.component_id}.{node.name}"

def comp_parameter(self, node: ComponentParameterNode) -> str:
return f"{node.component_id}.{node.name}"

# TODO: Add pretty print for node.instances_index
def time_operator(self, node: TimeOperatorNode) -> str:
return f"({visit(node.operand, self)}.{str(node.name)}({node.instances_index}))"
Expand Down
14 changes: 7 additions & 7 deletions src/andromede/expression/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,31 @@ def parameter(self, node: ParameterNode) -> T:
...

@abstractmethod
def time_operator(self, node: TimeOperatorNode) -> T:
def comp_parameter(self, node: ComponentParameterNode) -> T:
...

@abstractmethod
def time_aggregator(self, node: TimeAggregatorNode) -> T:
def comp_variable(self, node: ComponentVariableNode) -> T:
...

@abstractmethod
def scenario_operator(self, node: ScenarioOperatorNode) -> T:
def time_operator(self, node: TimeOperatorNode) -> T:
...

@abstractmethod
def port_field(self, node: PortFieldNode) -> T:
def time_aggregator(self, node: TimeAggregatorNode) -> T:
...

@abstractmethod
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T:
def scenario_operator(self, node: ScenarioOperatorNode) -> T:
...

@abstractmethod
def comp_parameter(self, node: ComponentParameterNode) -> T:
def port_field(self, node: PortFieldNode) -> T:
...

@abstractmethod
def comp_variable(self, node: ComponentVariableNode) -> T:
def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T:
...


Expand Down
Loading

0 comments on commit 61c563d

Please sign in to comment.