Skip to content

Commit

Permalink
Add type hints/simplify kernel_theoretical_timing
Browse files Browse the repository at this point in the history
Adding type hints made it possible to simplify `kernel_theoretical_timing`.
  • Loading branch information
Roman Cattaneo committed Jan 6, 2025
1 parent 63aca76 commit d4e065c
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions ndsl/dsl/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
class DaCeProgress:
"""Rough timer & log for major operations of DaCe build stack."""

def __init__(self, config: DaceConfig, label: str):
def __init__(self, config: DaceConfig, label: str) -> None:
self.prefix = DaCeProgress.default_prefix(config)
self.label = label

@classmethod
def default_prefix(cls, config: DaceConfig) -> str:
return f"[{config.get_orchestrate()}]"

def __enter__(self):
def __enter__(self) -> None:
ndsl_log.debug(f"{self.prefix} {self.label}...")
self.start = time.time()

def __exit__(self, _type, _val, _traceback):
def __exit__(self, _type, _val, _traceback) -> None:
elapsed = time.time() - self.start
ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed}s.")

Expand Down Expand Up @@ -133,7 +133,7 @@ def memory_static_analysis(
def report_memory_static_analysis(
sdfg: dace.sdfg.SDFG,
allocations: Dict[dace.StorageType, StorageReport],
detail_report=False,
detail_report: bool = False,
) -> str:
"""Create a human readable report from the memory analysis results"""
report = f"{sdfg.name}:\n"
Expand Down Expand Up @@ -168,7 +168,9 @@ def report_memory_static_analysis(
return report


def memory_static_analysis_from_path(sdfg_path: str, detail_report=False) -> str:
def memory_static_analysis_from_path(
sdfg_path: str, detail_report: bool = False
) -> str:
"""Open a SDFG and report the memory analysis"""
sdfg = dace.SDFG.from_file(sdfg_path)
return report_memory_static_analysis(
Expand All @@ -181,7 +183,7 @@ def memory_static_analysis_from_path(sdfg_path: str, detail_report=False) -> str
# ----------------------------------------------------------
# Theoretical bandwidth from SDFG
# ----------------------------------------------------------
def copy_kernel(q_in: FloatField, q_out: FloatField):
def copy_kernel(q_in: FloatField, q_out: FloatField) -> None:
with computation(PARALLEL), interval(...):
q_in = q_out

Expand All @@ -203,15 +205,15 @@ def __init__(self, size, backend) -> None:
)
orchestrate(obj=self, config=dace_config)

def __call__(self, A, B, n: int):
def __call__(self, A, B, n: int) -> None:
for i in dace.nounroll(range(n)):
self.copy_stencil(A, B)


def kernel_theoretical_timing(
sdfg: dace.sdfg.SDFG,
hardware_bw_in_GB_s=None,
backend=None,
hardware_bw_in_GB_s: Optional[float] = None,
backend: Optional[str] = None,
) -> Dict[str, float]:
"""Compute a lower timing bound for kernels with the following hypothesis:
Expand All @@ -221,7 +223,7 @@ def kernel_theoretical_timing(
- Memory pressure is mostly in read/write from global memory, inner scalar & shared
memory is not counted towards memory movement.
"""
if not hardware_bw_in_GB_s:
if hardware_bw_in_GB_s is None:
size = np.array(sdfg.arrays["__g_self__w"].shape)
print(
f"Calculating experimental hardware bandwidth on {size}"
Expand All @@ -246,13 +248,19 @@ def kernel_theoretical_timing(
bench(A, B, n)
dt.append((time.time() - s) / n)
memory_size_in_b = np.prod(size) * np.dtype(Float).itemsize * 8
bandwidth_in_bytes_s = memory_size_in_b / np.median(dt)
print(
f"Hardware bandwidth computed: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s"
)
else:
bandwidth_in_bytes_s = hardware_bw_in_GB_s * 1024 * 1024 * 1024
print(f"Given hardware bandwidth: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s")
measured_bandwidth_in_bytes_s = memory_size_in_b / np.median(dt)

bandwidth_in_bytes_s = (
measured_bandwidth_in_bytes_s
if hardware_bw_in_GB_s is None
else hardware_bw_in_GB_s * 1024 * 1024 * 1024
)
label = (
"Hardware bandwidth computed"
if hardware_bw_in_GB_s
else "Given hardware bandwidth"
)
print(f"{label}: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s")

allmaps = [
(me, state)
Expand Down Expand Up @@ -305,12 +313,6 @@ def kernel_theoretical_timing(
except TypeError:
pass

# Bad expansion
if not isinstance(newresult_in_us, sympy.core.numbers.Float) and not isinstance(
newresult_in_us, float
):
continue

result[node.label] = float(newresult_in_us)

return result
Expand Down

0 comments on commit d4e065c

Please sign in to comment.