-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor[cartesian]: unexpanded sdfg cleanups #1843
base: main
Are you sure you want to change the base?
Conversation
76c47ad
to
c7aba86
Compare
transient_storage_per_device: Dict[Literal["cpu", "gpu"], dace.StorageType] = { | ||
"cpu": dace.StorageType.Default, | ||
"gpu": dace.StorageType.GPU_Global, | ||
} | ||
|
||
device_type_per_device: Dict[Literal["cpu", "gpu"], dace.DeviceType] = { | ||
"cpu": dace.DeviceType.CPU, | ||
"gpu": dace.DeviceType.GPU, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have a named type for Literal["cpu", "gpu"]
? I found it typed like in the LayoutInfo
gt4py/src/gt4py/storage/cartesian/layout.py
Lines 33 to 37 in ac253b6
class LayoutInfo(TypedDict): | |
alignment: int # measured in bytes | |
device: Literal["cpu", "gpu"] | |
layout_map: Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]] | |
is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah we need an enum for this
Forward debug info from gt4py to dace if we know it. If we don't know, just don't specify instead of setting `DebugInfo(0)`. Moved `get_dace_debuginfo()` one folder higher from expansion utils into "general" dace utils because its not only used in expansion.
Transients are added in OirSDFGBuilder, where no array lifetime is configured. After building that SDFG, the lifetime of all transients is manually set to `Persistent` (which is an optimization leading to less frequent memory allocation in case a kernel is called multiple times). In this commit we directly specify the transient's lifetime when building the SDFG.
For GPU targets, we have to configure the `storage_type` for transient arrays. In addition, we have to set the library node's `device` property. We can do both while building the SDFG instead of separate passes afterwards.
553196d
to
38df291
Compare
(rebased onto |
Re: My cursory look seems to point that this is only called at top level for We should log a task to double check and we need a methodological/code way to differentiate code path we expect is in stencil mode and what is in orchestrated mode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good. Couple of question.
Do the enum
for the device type, the string is making me sad.
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) | ||
library_node.add_in_connector("__in_" + field) | ||
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) | ||
connector_name = CONNECTOR_PREFIX_IN + field |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: should we try to always use f-string
when we have string operation to have a visual easy way to pick on the type of things?
None, | ||
) | ||
for memlet in computation.write_memlets: | ||
if memlet.field not in write_acc_and_conn: | ||
write_acc_and_conn[memlet.field] = ( | ||
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
State gets no debug info? Do you expect the bottom to be enough? (I have no idea how this works inside DaCe)
Description
Refactors from recent debugging sessions around transient arrays in the "unexpanded SDFG" (the one with the library nodes):
**kwargs
fromOirSDFGBuilder
OirSDFGBuilder
OirSDFGBuilder
(for GPU targets)LibraryNode
's device type directly inOirSDFGBuilder
(for GPU targets)Sidenote on the allocation lifetime: In the orchestrated code path, we reset the allocation lifetime of transients to
SDFG
when we freeze the stencil with origin/domaingt4py/src/gt4py/cartesian/backend/dace_backend.py
Lines 270 to 317 in ac253b6
This might be relevant when tracking down orchestration performance. Seems odd at least.
Requirements
Covered by existing tests
N/A