refactor[cartesian]: unexpanded sdfg cleanups #1843

Open: romanc wants to merge 5 commits into main from romanc/unexpanded-sdfg-cleanups

romanc
@romanc romanc commented Feb 1, 2025

Description

Refactors from recent debugging sessions around transient arrays in the "unexpanded SDFG" (the one with the library nodes):

  • Remove unused **kwargs from OirSDFGBuilder
  • Forward debugging information about transient arrays to DaCe
  • Use a (constant) variable for connector prefixes of data going into/out of the library nodes
  • Configure the lifetime of transient arrays directly in OirSDFGBuilder
  • Configure storage type of transient arrays directly in OirSDFGBuilder (for GPU targets)
  • Configure the LibraryNode's device type directly in OirSDFGBuilder (for GPU targets)

Sidenote on the allocation lifetime: in the orchestrated code path, we reset the allocation lifetime of transients to `SDFG` when we freeze the stencil with origin/domain:

def freeze_origin_domain_sdfg(inner_sdfg, arg_names, field_info, *, origin, domain):
    wrapper_sdfg = dace.SDFG("frozen_" + inner_sdfg.name)
    state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state")
    inputs = set()
    outputs = set()
    for inner_state in inner_sdfg.nodes():
        for node in inner_state.nodes():
            if (
                not isinstance(node, dace.nodes.AccessNode)
                or inner_sdfg.arrays[node.data].transient
            ):
                continue
            if node.has_reads(inner_state):
                inputs.add(node.data)
            if node.has_writes(inner_state):
                outputs.add(node.data)
    nsdfg = state.add_nested_sdfg(inner_sdfg, None, inputs, outputs)
    _sdfg_add_arrays_and_edges(
        field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origins=origin
    )
    # in special case of empty domain, remove entire SDFG.
    if any(d == 0 for d in domain):
        states = wrapper_sdfg.states()
        assert len(states) == 1
        for node in states[0].nodes():
            state.remove_node(node)
    # make sure that symbols are passed through to inner sdfg
    for symbol in nsdfg.sdfg.free_symbols:
        if symbol not in wrapper_sdfg.symbols:
            wrapper_sdfg.add_symbol(symbol, nsdfg.sdfg.symbols[symbol])
    # Try to inline wrapped SDFG before symbols are specialized to avoid extra views
    inline_sdfgs(wrapper_sdfg)
    _sdfg_specialize_symbols(wrapper_sdfg, domain)
    for _, _, array in wrapper_sdfg.arrays_recursive():
        if array.transient:
            array.lifetime = dace.dtypes.AllocationLifetime.SDFG
    wrapper_sdfg.arg_names = arg_names
    return wrapper_sdfg

This might be relevant when tracking down orchestration performance. Seems odd at least.

Requirements

  • All fixes and/or new features come with corresponding tests.
    Covered by existing tests
  • Important design decisions have been documented in the appropriate ADR inside the docs/development/ADRs/ folder.
    N/A

@romanc romanc force-pushed the romanc/unexpanded-sdfg-cleanups branch 2 times, most recently from 76c47ad to c7aba86 Compare February 3, 2025 12:52
@romanc romanc marked this pull request as ready for review February 3, 2025 18:18
Comment on lines +36 to +44
transient_storage_per_device: Dict[Literal["cpu", "gpu"], dace.StorageType] = {
    "cpu": dace.StorageType.Default,
    "gpu": dace.StorageType.GPU_Global,
}

device_type_per_device: Dict[Literal["cpu", "gpu"], dace.DeviceType] = {
    "cpu": dace.DeviceType.CPU,
    "gpu": dace.DeviceType.GPU,
}
Contributor Author

Should we have a named type for Literal["cpu", "gpu"]? I found it typed like this in LayoutInfo:

class LayoutInfo(TypedDict):
    alignment: int  # measured in bytes
    device: Literal["cpu", "gpu"]
    layout_map: Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]]
    is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool]

Contributor

Yeah we need an enum for this
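
A minimal sketch of what such an enum could look like. The name `Device`, the helper `storage_for`, and the string stand-ins for the DaCe enum members are assumptions for illustration, not code from this PR. In gt4py, the mapped values would be `dace.StorageType` members rather than strings.

```python
from enum import Enum
from typing import Dict


class Device(str, Enum):
    """Hypothetical named type replacing Literal["cpu", "gpu"]."""

    CPU = "cpu"
    GPU = "gpu"


# Strings stand in for dace.StorageType.Default / dace.StorageType.GPU_Global
# so that this sketch runs without DaCe installed.
transient_storage_per_device: Dict[Device, str] = {
    Device.CPU: "Default",
    Device.GPU: "GPU_Global",
}


def storage_for(device: str) -> str:
    # Device(...) raises ValueError for anything other than "cpu" or "gpu",
    # so typos surface at the call site instead of as a KeyError later on.
    return transient_storage_per_device[Device(device)]
```

Mixing in `str` keeps the members comparable to the existing `"cpu"`/`"gpu"` strings, which might ease a gradual migration.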

Roman Cattaneo and others added 5 commits February 10, 2025 09:09

Forward debug info from gt4py to dace if we know it. If we don't know,
just don't specify instead of setting `DebugInfo(0)`.

Moved `get_dace_debuginfo()` one folder higher from expansion utils into
"general" dace utils because it's not only used in expansion.

Transients are added in OirSDFGBuilder, where no array lifetime is
configured. After building that SDFG, the lifetime of all transients is
manually set to `Persistent` (which is an optimization leading to less
frequent memory allocation in case a kernel is called multiple times).
In this commit we directly specify the transient's lifetime when
building the SDFG.

For GPU targets, we have to configure the `storage_type` for transient
arrays. In addition, we have to set the library node's `device` property.
We can do both while building the SDFG instead of separate passes afterwards.

@romanc romanc force-pushed the romanc/unexpanded-sdfg-cleanups branch from 553196d to 38df291 Compare February 10, 2025 08:09
@romanc
romanc commented Feb 10, 2025

(rebased onto main branch because there was a hiccup on the CSCS-CI)

@FlorianDeconinck
Contributor

Re: freeze_origin_domain_sdfg

From a cursory look, this seems to be called only at the top level for DaceLazyStencil et al., i.e. it's not called when we are doing orchestration, only for the top-level arguments (and then it's fine).

We should log a task to double-check this, and we need a methodological/code-level way to differentiate the code paths we expect to run in stencil mode from those in orchestrated mode.

@FlorianDeconinck FlorianDeconinck left a comment

Looking good. A couple of questions.

Do the enum for the device type, the string is making me sad.

- access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
- library_node.add_in_connector("__in_" + field)
+ access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
+ connector_name = CONNECTOR_PREFIX_IN + field
Contributor

Nitpick: should we try to always use f-strings for string operations, so the types involved are easy to pick out visually?
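
For illustration, the f-string variant of the connector name might look like the sketch below. The value of `CONNECTOR_PREFIX_IN` is assumed from the `"__in_"` literal in the diff, and `in_connector_name` is a hypothetical helper, not part of the PR.

```python
# Assumed value: the diff swaps the literal "__in_" for this constant.
CONNECTOR_PREFIX_IN = "__in_"


def in_connector_name(field: str) -> str:
    # Same result as CONNECTOR_PREFIX_IN + field, but the f-string makes it
    # clear at a glance that the expression yields a string.
    return f"{CONNECTOR_PREFIX_IN}{field}"
```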

        None,
    )
for memlet in computation.write_memlets:
    if memlet.field not in write_acc_and_conn:
        write_acc_and_conn[memlet.field] = (
            sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
Contributor

State gets no debug info? Do you expect the bottom to be enough? (I have no idea how this works inside DaCe)
