Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structure Support to NestedSDFGs and Python Frontend #1366

Merged
merged 92 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
323a23d
Added 'may_alias' property to Stucture class.
alexnick83 Sep 6, 2023
959d609
When creating copy expressions, replace dots with arrows if the root …
alexnick83 Sep 6, 2023
0c02341
When initializing the CPU code generator, specialize Structure defini…
alexnick83 Sep 6, 2023
2ccc620
Specializes how Structures are added to a nested scope. Attribute vis…
alexnick83 Sep 6, 2023
480bd9a
The sdfg submodule now exposes NestedDict to the rest of DaCe.
alexnick83 Sep 6, 2023
7e7f635
Do not create (double) pointer for Structure connectors to NestedSDFG…
alexnick83 Sep 6, 2023
36e5ed6
Use the root of the data name for NestedSDFG connector validation.
alexnick83 Sep 6, 2023
5ae8e7d
Added structure test written in Python.
alexnick83 Sep 6, 2023
cee8ece
Added writing test.
alexnick83 Sep 6, 2023
ecec25b
C++ array expression generator supports nested data.
alexnick83 Sep 6, 2023
583c7c9
Assignment visitor method supports nested data.
alexnick83 Sep 6, 2023
cc6223a
NestedDict fix for attributed lookups where the root is not a Structure.
alexnick83 Sep 6, 2023
27d1222
Fix for symbolic replacement/equality failures.
alexnick83 Sep 6, 2023
16d1244
Renamed test file
alexnick83 Sep 6, 2023
b590a6e
Changes Memlet API used.
alexnick83 Sep 6, 2023
8a0c63c
Fixes incompatibility with NestedDict.
alexnick83 Sep 6, 2023
96e763d
Datadesc names cannot have dots.
alexnick83 Sep 6, 2023
a0f02fa
Added mini-app test.
alexnick83 Sep 6, 2023
3056d49
Merge branch 'master' into frontend-add-structure-support
alexnick83 Sep 27, 2023
1b0c074
Filter symbol-mapping by used-symbols when generating nested SDFG cal…
alexnick83 Sep 27, 2023
1c097a3
Merge branch 'sym-attr' into frontend-add-structure-support
alexnick83 Sep 28, 2023
959cdeb
Updated emit memlet reference method
alexnick83 Sep 28, 2023
57cf574
Added optional property to Structures.
alexnick83 Sep 28, 2023
d2ab611
Merge branch 'master' into frontend-add-structure-support
alexnick83 Sep 28, 2023
0b9be8a
Use "used" symbols.
alexnick83 Sep 29, 2023
54232cd
Use "used" symbols
alexnick83 Sep 29, 2023
a975435
Add desc symbols.
alexnick83 Sep 29, 2023
5098a0d
GPU-global mode.
alexnick83 Sep 29, 2023
14b4a1e
Merge branch 'master' into frontend-add-structure-support
alexnick83 Oct 3, 2023
3df1a89
Merge branch 'used-symbol-fixes' into frontend-add-structure-support
alexnick83 Oct 3, 2023
4a9a0c6
Added subs method to Attr. Adjusted Attr printing in DaCeSympyPrinter…
alexnick83 Oct 3, 2023
7266d02
Transpose's pure replacement now properly supports 2D slices from ND …
alexnick83 Oct 3, 2023
663a93c
Before calling subs on a symbolic expression, add to the "filtered" s…
alexnick83 Oct 3, 2023
3308b53
InlineSDFG now replaces nested desc names with the top-level names in…
alexnick83 Oct 3, 2023
9693a4c
Attr free symbols should exclude array indexing.
alexnick83 Oct 4, 2023
efd329d
In ConstantPropagation, add to "arrays" any nested data.
alexnick83 Oct 4, 2023
ac49626
Changed parameter names in eye/identity Maps.
alexnick83 Oct 4, 2023
35d23fe
Transpose fix.
alexnick83 Oct 4, 2023
1c14f8f
Experimenting with new test.
alexnick83 Oct 4, 2023
9177525
Merge branch 'master' into frontend-add-structure-support
alexnick83 Oct 5, 2023
25f3972
Cleaned up tests.
alexnick83 Oct 5, 2023
910cde3
Added `keys` method for nested dicts and data. Improvements in findin…
alexnick83 Oct 6, 2023
6a3a6ac
Improvements in determining allocation lifetime for Structures.
alexnick83 Oct 6, 2023
9b56c3e
Improvements in replacing transient Structure names with their Python…
alexnick83 Oct 6, 2023
83f197b
Using root data in the case of Structures.
alexnick83 Oct 6, 2023
6a9c6cd
Added define local structure replacement method.
alexnick83 Oct 6, 2023
e8868c1
Added root data/desc helper methods.
alexnick83 Oct 6, 2023
d551540
Added new test.
alexnick83 Oct 6, 2023
74ccd78
Merge branch 'master' into frontend-add-structure-support
alexnick83 Oct 19, 2023
0a48922
Merge branch 'master' into frontend-add-structure-support
alexnick83 Oct 20, 2023
1e27645
Merge branch 'master' into frontend-add-structure-support
alexnick83 Oct 25, 2023
1a6737e
emit_memlet_reference method is not used any more to define nested SD…
alexnick83 Oct 25, 2023
6397025
_generate_NestedSDFG method now defines nested SDFG arguments. alloca…
alexnick83 Oct 25, 2023
1165ba2
Reworked determine_allocation_lifetime to potentially allocate nested…
alexnick83 Oct 25, 2023
734dba9
Enhanced arrays_recursive and shared_transients methods for better ne…
alexnick83 Oct 25, 2023
1843182
Enabled all tests.
alexnick83 Oct 25, 2023
29fcb2a
Merge branch 'master' into frontend-add-structure-support
alexnick83 Nov 11, 2023
8a0db59
Fixed access to structure members' keys.
alexnick83 Nov 11, 2023
6ba3651
Fixed bad merge.
alexnick83 Nov 11, 2023
ea6ea51
Don't eliminate structures that have members.
alexnick83 Nov 11, 2023
2bc215c
Disable serialization testing.
alexnick83 Nov 11, 2023
c778fa9
Fixed serialization disabling.
alexnick83 Nov 12, 2023
22e3198
Merge branch 'master' into frontend-add-structure-support
alexnick83 Nov 12, 2023
cbaade3
Don't filter by defined symbols.
alexnick83 Nov 12, 2023
3f4323b
Fixed number of values to unpack.
alexnick83 Nov 12, 2023
ce3a911
Ensure that src/dst subsets exist before using them.
alexnick83 Nov 12, 2023
5ee923b
Renamed diag to diagonal to avoid sympy clash.
alexnick83 Nov 12, 2023
24593b4
Removed property replacement.
alexnick83 Nov 12, 2023
0b885af
Switched to using subset.
alexnick83 Nov 12, 2023
e333210
Updated tests.
alexnick83 Nov 13, 2023
31ee757
Added clone and pool
alexnick83 Nov 13, 2023
8f6fd16
Updated structure codegen for CUDA
alexnick83 Nov 13, 2023
ab39d5c
Fixed number of unpacked values.
alexnick83 Nov 13, 2023
8eeb622
OpenBLAS's transpose needs float and double pointers instead of std::…
alexnick83 Nov 13, 2023
15fb33c
Clean up.
alexnick83 Nov 13, 2023
456c913
Merge branch 'master' into frontend-add-structure-support
alexnick83 Nov 14, 2023
b5160f4
Addressed review comments.
alexnick83 Nov 16, 2023
77b4b37
Merge branch 'master' into frontend-add-structure-support
alexnick83 Nov 24, 2023
affde40
Merge branch 'master' into frontend-add-structure-support
alexnick83 Dec 14, 2023
daad8fe
Removed commented out code.
alexnick83 Dec 20, 2023
068b841
Using root-data.
alexnick83 Dec 20, 2023
cf68564
Merge branch 'master' into frontend-add-structure-support
alexnick83 Dec 20, 2023
c3c2616
Removed old methods.
alexnick83 Dec 20, 2023
30ffa5e
Merge branch 'master' into frontend-add-structure-support
alexnick83 Jan 15, 2024
d03958a
Disabled serialization in covariance test.
alexnick83 Jan 18, 2024
75c6a77
Merge branch 'master' into frontend-add-structure-support
tbennun Jan 20, 2024
0665b1a
Merge branch 'master' into frontend-add-structure-support
tbennun Feb 16, 2024
5d584e1
Merge branch 'master' into frontend-add-structure-support
tbennun Feb 18, 2024
cf327a7
Merge branch 'master' into frontend-add-structure-support
alexnick83 Feb 19, 2024
c3787b6
Fixed possible wrong identation. Fixed missing nodes dictionary.
alexnick83 Feb 19, 2024
17fb666
Disabled test (temporarily, see PR #1524)
alexnick83 Feb 19, 2024
f85d3ae
Merge branch 'master' into frontend-add-structure-support
alexnick83 Feb 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def copy_expr(
packed_types=False,
):
data_desc = sdfg.arrays[data_name]
# NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# TODO: Study this when changing Structures to be (optionally?) non-pointers.
tokens = data_name.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = data_name.replace('.', '->')
else:
name = data_name
ptrname = ptr(data_name, data_desc, sdfg, dispatcher.frame)
if relative_offset:
s = memlet.subset
Expand Down Expand Up @@ -99,6 +106,7 @@ def copy_expr(
# get conf flag
decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces")

# TODO: Study structures on FPGAs. Should probably use 'name' instead of 'data_name' here.
expr = fpga.fpga_ptr(
data_name,
data_desc,
Expand All @@ -112,7 +120,7 @@ def copy_expr(
and not isinstance(data_desc, data.View),
decouple_array_interfaces=decouple_array_interfaces)
else:
expr = ptr(data_name, data_desc, sdfg, dispatcher.frame)
expr = ptr(name, data_desc, sdfg, dispatcher.frame)

add_offset = offset_cppstr != "0"

Expand Down Expand Up @@ -344,7 +352,7 @@ def make_const(expr: str) -> str:
is_scalar = False
elif defined_type == DefinedType.Scalar:
typedef = defined_ctype if is_scalar else (defined_ctype + '*')
if is_write is False:
if is_write is False and not isinstance(desc, data.Structure):
typedef = make_const(typedef)
ref = '&' if is_scalar else ''
defined_type = DefinedType.Scalar if is_scalar else DefinedType.Pointer
Expand Down Expand Up @@ -578,17 +586,26 @@ def cpp_array_expr(sdfg,
desc = (sdfg.arrays[memlet.data] if referenced_array is None else referenced_array)
offset_cppstr = cpp_offset_expr(desc, s, o, packed_veclen, indices=indices)

# NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs?
# TODO: Study this when changing Structures to be (optionally?) non-pointers.
tokens = memlet.data.split('.')
if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure):
name = memlet.data.replace('.', '->')
else:
name = memlet.data

if with_brackets:
if fpga.is_fpga_array(desc):
# get conf flag
decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces")
# TODO: Study structures on FPGAs. Should probably use 'name' instead of 'memlet.data' here.
ptrname = fpga.fpga_ptr(memlet.data,
desc,
sdfg,
subset,
decouple_array_interfaces=decouple_array_interfaces)
else:
ptrname = ptr(memlet.data, desc, sdfg, codegen)
ptrname = ptr(name, desc, sdfg, codegen)
return "%s[%s]" % (ptrname, offset_cppstr)
else:
return offset_cppstr
Expand Down
131 changes: 75 additions & 56 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,7 @@ class CPUCodeGen(TargetCodeGenerator):
target_name = "cpu"
language = "cpp"

def __init__(self, frame_codegen, sdfg):
self._frame = frame_codegen
self._dispatcher: TargetDispatcher = frame_codegen.dispatcher
self.calling_codegen = self
dispatcher = self._dispatcher

self._locals = cppunparse.CPPLocals()
# Scope depth (for defining locals)
self._ldepth = 0

# Keep nested SDFG schedule when descending into it
self._toplevel_schedule = None

# FIXME: this allows other code generators to change the CPU
# behavior to assume that arrays point to packed types, thus dividing
# all addresess by the vector length.
self._packed_types = False

# Keep track of traversed nodes
self._generated_nodes = set()

# Keep track of generated NestedSDG, and the name of the assigned function
self._generated_nested_sdfg = dict()
def _define_sdfg_arguments(self, sdfg, arglist):

# NOTE: Multi-nesting with StructArrays must be further investigated.
def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
Expand All @@ -66,18 +44,18 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
args[f'{prefix}->{k}'] = v

# Keeps track of generated connectors, so we know how to access them in nested scopes
arglist = dict(self._frame.arglist)
for name, arg_type in self._frame.arglist.items():
args = dict(arglist)
for name, arg_type in arglist.items():
if isinstance(arg_type, data.Structure):
desc = sdfg.arrays[name]
_visit_structure(arg_type, arglist, name)
_visit_structure(arg_type, args, name)
elif isinstance(arg_type, data.StructArray):
desc = sdfg.arrays[name]
desc = desc.stype
_visit_structure(desc, arglist, name)
_visit_structure(desc, args, name)

for name, arg_type in arglist.items():
if isinstance(arg_type, (data.Scalar, data.Structure)):
for name, arg_type in args.items():
if isinstance(arg_type, data.Scalar):
# GPU global memory is only accessed via pointers
# TODO(later): Fix workaround somehow
if arg_type.storage is dtypes.StorageType.GPU_Global:
Expand All @@ -92,10 +70,40 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
self._dispatcher.defined_vars.add(name, DefinedType.StreamArray, arg_type.as_arg(name=''))
else:
self._dispatcher.defined_vars.add(name, DefinedType.Stream, arg_type.as_arg(name=''))
elif isinstance(arg_type, data.Structure):
self._dispatcher.defined_vars.add(name, DefinedType.Pointer, arg_type.dtype.ctype)
else:
raise TypeError("Unrecognized argument type: {t} (value {v})".format(t=type(arg_type).__name__,
v=str(arg_type)))

def __init__(self, frame_codegen, sdfg):
self._frame = frame_codegen
self._dispatcher: TargetDispatcher = frame_codegen.dispatcher
self.calling_codegen = self
dispatcher = self._dispatcher

self._locals = cppunparse.CPPLocals()
# Scope depth (for defining locals)
self._ldepth = 0

# Keep nested SDFG schedule when descending into it
self._toplevel_schedule = None

# FIXME: this allows other code generators to change the CPU
# behavior to assume that arrays point to packed types, thus dividing
# all addresess by the vector length.
self._packed_types = False

# Keep track of traversed nodes
self._generated_nodes = set()

# Keep track of generated NestedSDG, and the name of the assigned function
self._generated_nested_sdfg = dict()

# Keeps track of generated connectors, so we know how to access them in nested scopes
arglist = dict(self._frame.arglist)
self._define_sdfg_arguments(sdfg, arglist)

# Register dispatchers
dispatcher.register_node_dispatcher(self)
dispatcher.register_map_dispatcher(
Expand Down Expand Up @@ -258,7 +266,7 @@ def declare_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, de
raise NotImplementedError("The declare_array method should only be used for variables "
"that must have their declaration and allocation separate.")

name = node.data
name = node.root_data
ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame)

if nodedesc.transient is False:
Expand Down Expand Up @@ -295,23 +303,40 @@ def declare_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, de
raise NotImplementedError("Unimplemented storage type " + str(nodedesc.storage))

def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, declaration_stream,
allocation_stream):
name = node.data
alloc_name = cpp.ptr(name, nodedesc, sdfg, self._frame)
allocation_stream, allocate_nested_data: bool = True):
alloc_name = cpp.ptr(node.data, nodedesc, sdfg, self._frame)
name = alloc_name

if nodedesc.transient is False:
tokens = node.data.split('.')
top_desc = sdfg.arrays[tokens[0]]
# NOTE: Assuming here that all Structure members share transient/storage/lifetime properties.
# TODO: Study what is needed in the DaCe stack to ensure this assumption is correct.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like with the other comments, if we cannot guarantee in the stack that this assumption is an invariant across our pipeline, we may want to verify and warn

top_transient = top_desc.transient
top_storage = top_desc.storage
top_lifetime = top_desc.lifetime

if top_transient is False:
return

# Check if array is already allocated
if self._dispatcher.defined_vars.has(name):
return

# Check if array is already declared
declared = self._dispatcher.declared_arrays.has(name)

if len(tokens) > 1:
for i in range(len(tokens) - 1):
tmp_name = '.'.join(tokens[:i + 1])
tmp_alloc_name = cpp.ptr(tmp_name, sdfg.arrays[tmp_name], sdfg, self._frame)
if not self._dispatcher.defined_vars.has(tmp_alloc_name):
self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(tmp_name), sdfg.arrays[tmp_name],
function_stream, declaration_stream, allocation_stream,
allocate_nested_data=False)
declared = True
else:
# Check if array is already declared
declared = self._dispatcher.declared_arrays.has(name)

define_var = self._dispatcher.defined_vars.add
if nodedesc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External):
if top_lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External):
define_var = self._dispatcher.defined_vars.add_global
nodedesc = update_persistent_desc(nodedesc, sdfg)

Expand All @@ -324,13 +349,14 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d
if isinstance(nodedesc, data.Structure) and not isinstance(nodedesc, data.StructureView):
declaration_stream.write(f"{nodedesc.ctype} {name} = new {nodedesc.dtype.base_type};\n")
define_var(name, DefinedType.Pointer, nodedesc.ctype)
for k, v in nodedesc.members.items():
if isinstance(v, data.Data):
ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype
defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer
self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef)
self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream,
declaration_stream, allocation_stream)
if allocate_nested_data:
for k, v in nodedesc.members.items():
if isinstance(v, data.Data):
ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype
defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer
self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef)
self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream,
declaration_stream, allocation_stream)
return
if isinstance(nodedesc, (data.StructureView, data.View)):
return self.allocate_view(sdfg, dfg, state_id, node, function_stream, declaration_stream, allocation_stream)
Expand Down Expand Up @@ -620,17 +646,6 @@ def _emit_copy(
#############################################
# Corner cases

# Writing one index
if (isinstance(memlet.subset, subsets.Indices) and memlet.wcr is None
and self._dispatcher.defined_vars.get(vconn)[0] == DefinedType.Scalar):
stream.write(
"%s = %s;" % (vconn, self.memlet_ctor(sdfg, memlet, dst_nodedesc.dtype, False)),
sdfg,
state_id,
[src_node, dst_node],
)
return

# Setting a reference
if isinstance(dst_nodedesc, data.Reference) and orig_vconn == 'set':
srcptr = cpp.ptr(src_node.data, src_nodedesc, sdfg, self._frame)
Expand Down Expand Up @@ -1587,6 +1602,10 @@ def _generate_NestedSDFG(
self._dispatcher.defined_vars.enter_scope(sdfg, can_access_parent=inline)
state_dfg = sdfg.nodes()[state_id]

fsyms = self._frame.free_symbols(node.sdfg)
arglist = node.sdfg.arglist(scalars_only=False, free_symbols=fsyms)
self._define_sdfg_arguments(node.sdfg, arglist)

# Quick sanity check.
# TODO(later): Is this necessary or "can_access_parent" should always be False?
if inline:
Expand Down
6 changes: 4 additions & 2 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,10 +1023,12 @@
if issubclass(node_dtype.type, ctypes.Structure):
callsite_stream.write('for (size_t __idx = 0; __idx < {arrlen}; ++__idx) '
'{{'.format(arrlen=array_length))
for field_name, field_type in node_dtype._data.items():
# TODO: Study further when tackling Structures on GPU.
for field_name, field_type in node_dtype._typeclass.fields.items():

Check warning on line 1027 in dace/codegen/targets/cuda.py

View check run for this annotation

Codecov / codecov/patch

dace/codegen/targets/cuda.py#L1027

Added line #L1027 was not covered by tests
if isinstance(field_type, dtypes.pointer):
tclass = field_type.type
length = node_dtype._length[field_name]

length = node_dtype._typeclass._length[field_name]

Check warning on line 1031 in dace/codegen/targets/cuda.py

View check run for this annotation

Codecov / codecov/patch

dace/codegen/targets/cuda.py#L1031

Added line #L1031 was not covered by tests
size = 'sizeof({})*{}[__idx].{}'.format(dtypes._CTYPES[tclass], str(src_node), length)
callsite_stream.write('DACE_GPU_CHECK({backend}Malloc(&{dst}[__idx].{fname}, '
'{sz}));'.format(dst=str(dst_node),
Expand Down
26 changes: 16 additions & 10 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
reachability = StateReachability().apply_pass(top_sdfg, {})
access_instances: Dict[int, Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
shared_transients[sdfg.sdfg_id] = sdfg.shared_transients(check_toplevel=False)
shared_transients[sdfg.sdfg_id] = sdfg.shared_transients(check_toplevel=False, include_nested_data=True)
fsyms[sdfg.sdfg_id] = self.symbols_and_constants(sdfg)

#############################################
Expand All @@ -564,8 +564,14 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):

access_instances[sdfg.sdfg_id] = instances

for sdfg, name, desc in top_sdfg.arrays_recursive():
if not desc.transient:
for sdfg, name, desc in top_sdfg.arrays_recursive(include_nested_data=True):
# NOTE: Assuming here that all Structure members share transient/storage/lifetime properties.
# TODO: Study what is needed in the DaCe stack to ensure this assumption is correct.
top_desc = sdfg.arrays[name.split('.')[0]]
top_transient = top_desc.transient
top_storage = top_desc.storage
top_lifetime = top_desc.lifetime
if not top_transient:
continue
if name in sdfg.constants_prop:
# Constants do not need to be allocated
Expand All @@ -589,7 +595,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
access_instances[sdfg.sdfg_id].get(name, [(None, None)])[-1]

# Cases
if desc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External):
if top_lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External):
# Persistent memory is allocated in initialization code and
# exists in the library state structure

Expand All @@ -599,13 +605,13 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):

definition = desc.as_arg(name=f'__{sdfg.sdfg_id}_{name}') + ';'

if desc.storage != dtypes.StorageType.CPU_ThreadLocal: # If thread-local, skip struct entry
if top_storage != dtypes.StorageType.CPU_ThreadLocal: # If thread-local, skip struct entry
self.statestruct.append(definition)

self.to_allocate[top_sdfg].append((sdfg, first_state_instance, first_node_instance, True, True, True))
self.where_allocated[(sdfg, name)] = top_sdfg
continue
elif desc.lifetime is dtypes.AllocationLifetime.Global:
elif top_lifetime is dtypes.AllocationLifetime.Global:
# Global memory is allocated in the beginning of the program
# exists in the library state structure (to be passed along
# to the right SDFG)
Expand All @@ -627,15 +633,15 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
# a kernel).
alloc_scope: Union[nodes.EntryNode, SDFGState, SDFG] = None
alloc_state: SDFGState = None
if (name in shared_transients[sdfg.sdfg_id] or desc.lifetime is dtypes.AllocationLifetime.SDFG):
if (name in shared_transients[sdfg.sdfg_id] or top_lifetime is dtypes.AllocationLifetime.SDFG):
# SDFG descriptors are allocated in the beginning of their SDFG
alloc_scope = sdfg
if first_state_instance is not None:
alloc_state = first_state_instance
# If unused, skip
if first_node_instance is None:
continue
elif desc.lifetime == dtypes.AllocationLifetime.State:
elif top_lifetime == dtypes.AllocationLifetime.State:
# State memory is either allocated in the beginning of the
# containing state or the SDFG (if used in more than one state)
curstate: SDFGState = None
Expand All @@ -651,7 +657,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
else:
alloc_scope = curstate
alloc_state = curstate
elif desc.lifetime == dtypes.AllocationLifetime.Scope:
elif top_lifetime == dtypes.AllocationLifetime.Scope:
# Scope memory (default) is either allocated in the innermost
# scope (e.g., Map, Consume) it is used in (i.e., greatest
# common denominator), or in the SDFG if used in multiple states
Expand All @@ -671,7 +677,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG):
for node in state.nodes():
if not isinstance(node, nodes.AccessNode):
continue
if node.data != name:
if node.root_data != name:
continue

# If already found in another state, set scope to SDFG
Expand Down
Loading
Loading