Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jul 25, 2024
1 parent 84f15d5 commit c1ff003
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def _add_storage(
Add storage for data containers used in the SDFG. For fields, it allocates dace arrays,
while scalars are stored as SDFG symbols.
The fields used as temporary arrays, when `transient = True`, are allocated internally
in the SDFG scope; when `transient = False`, the fields are allocated by the SDFG caller.
The latter case (external arrays) is for fields passed as program arguments.
The fields used as temporary arrays, when `transient = True`, are allocated and exist
only within the SDFG; when `transient = False`, the fields have to be allocated outside
and have to be passed as array arguments to the SDFG.
"""
if isinstance(symbol_type, ts.FieldType):
dtype = dace_fieldview_util.as_dace_type(symbol_type.dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def _add_entry_memlet_path(
) -> None:
self.input_connections.append((src, src_subset, dst_node, dst_conn))

def _add_edge(
self,
src_node: dace.Node,
src_node_connector: Optional[str],
dst_node: dace.Node,
dst_node_connector: Optional[str],
memlet: dace.Memlet,
) -> None:
"""Helper method to add an edge in current state."""
self.state.add_edge(src_node, src_node_connector, dst_node, dst_node_connector, memlet)

def _add_map(
self,
name: str,
Expand All @@ -121,7 +132,7 @@ def _add_map(
],
**kwargs: Any,
) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]:
"""Helper method to add a map with unique ame in current state."""
"""Helper method to add a map with unique name in current state."""
return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs)

def _add_tasklet(
Expand All @@ -132,7 +143,7 @@ def _add_tasklet(
code: str,
**kwargs: Any,
) -> dace.nodes.Tasklet:
"""Helper method to add a tasklet with unique ame in current state."""
"""Helper method to add a tasklet with unique name in current state."""
return self.subgraph_builder.add_tasklet(name, self.state, inputs, outputs, code, **kwargs)

def _get_tasklet_result(
Expand All @@ -145,7 +156,7 @@ def _get_tasklet_result(
self.sdfg.add_scalar(temp_name, dtype, transient=True)
data_type = dace_fieldview_util.as_scalar_type(str(dtype.as_numpy_dtype()))
temp_node = self.state.add_access(temp_name)
self.state.add_edge(
self._add_edge(
src_node,
src_connector,
temp_node,
Expand Down Expand Up @@ -228,7 +239,7 @@ def _visit_deref(self, node: gtir.FunCall) -> MemletExpr | ValueExpr:
)

elif isinstance(index_expr, ValueExpr):
self.state.add_edge(
self._add_edge(
index_expr.node,
None,
deref_node,
Expand Down Expand Up @@ -319,7 +330,7 @@ def _make_cartesian_shift(
input_connector,
)
elif isinstance(input_expr, ValueExpr):
self.state.add_edge(
self._add_edge(
input_expr.node,
None,
dynamic_offset_tasklet,
Expand Down Expand Up @@ -374,7 +385,7 @@ def _make_dynamic_neighbor_offset(
"offset",
)
else:
self.state.add_edge(
self._add_edge(
offset_expr.node,
None,
tasklet_node,
Expand Down Expand Up @@ -495,7 +506,7 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | MemletExpr | Value

for connector, arg_expr in node_connections.items():
if isinstance(arg_expr, ValueExpr):
self.state.add_edge(
self._add_edge(
arg_expr.node,
None,
tasklet_node,
Expand Down

0 comments on commit c1ff003

Please sign in to comment.