From 90727783183c3c780da21976d00802957e3bf81c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 22 Jan 2025 15:24:49 +0100 Subject: [PATCH 1/5] fix for type inference --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f5aeac7943..1d37e6e79a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -149,7 +149,9 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType )(functools.partial(if_, pred))(true_branch, false_branch) assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType) - assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + assert isinstance(pred, ts.DeferredType) or ( + isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + ) # TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are # iterators defined on different positions this fails. For the GTFN backend we also don't # want this, but for roundtrip it is totally fine. From 6c4ad129f37b56e6b32105c2899ef117e765b7a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 31 Jan 2025 07:55:57 +0100 Subject: [PATCH 2/5] Fix typing --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 2 +- uv.lock | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 1d37e6e79a..37308fdf48 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -141,7 +141,7 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: @_register_builtin_type_synthesizer -def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: +def if_(pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): return tree_map( collection_type=ts.TupleType, diff --git a/uv.lock b/uv.lock index 1d050717af..2165abd59e 100644 --- a/uv.lock +++ b/uv.lock @@ -384,7 +384,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -681,7 +681,7 @@ dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "ply" }, - { name = "pyreadline", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "pyyaml" }, { name = "sympy" }, ] @@ -998,7 +998,6 @@ wheels = [ [[package]] name = "gt4py" -version = "1.0.4" source = { editable = "." } dependencies = [ { name = "attrs" }, @@ -1359,7 +1358,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "appnope", marker = "sys_platform == 'darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, From 947b12ce333462623885549145e9fabca4016cdb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 31 Jan 2025 08:02:34 +0100 Subject: [PATCH 3/5] Fix format --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 37308fdf48..19ab3ecdda 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -141,7 +141,9 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: @_register_builtin_type_synthesizer -def if_(pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: +def if_( + pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType +) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): return tree_map( collection_type=ts.TupleType, From ad1fec8d4e51c20311941ed1f69b7af30dd7cea6 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 31 Jan 2025 08:14:51 +0100 Subject: [PATCH 4/5] Fix format --- uv.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/uv.lock b/uv.lock index 2165abd59e..212c398e0c 100644 --- a/uv.lock +++ b/uv.lock @@ -998,6 +998,7 @@ wheels = [ [[package]] name = "gt4py" +version = "1.0.4" source = { editable = "." } dependencies = [ { name = "attrs" }, From aff54dbdb7b26d46783166774ecd08612a9656bb Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 31 Jan 2025 10:43:25 +0100 Subject: [PATCH 5/5] Revert uv.lock changes --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 212c398e0c..1d050717af 100644 --- a/uv.lock +++ b/uv.lock @@ -384,7 +384,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "colorama", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -681,7 +681,7 @@ dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "ply" }, - { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyreadline", marker = "platform_system == 'Windows' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "pyyaml" }, { name = "sympy" }, ] @@ -1359,7 +1359,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "appnope", marker = "platform_system == 'Darwin' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" },