From b7b8a49b3522c90649d6b38aa3466322444ace3d Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 12 Aug 2024 18:09:55 +0100 Subject: [PATCH 1/4] Update xla to use mlir rather than backend-specific-translations --- envpool/python/xla_template.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index 4238f945..ae934a12 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -23,6 +23,7 @@ from jax.core import ShapedArray from jax.interpreters import xla from jax.lib import xla_client +from jax import interpreters def _shape_with_layout( @@ -91,10 +92,10 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: prim.multiple_results = (len(out_specs) > 1) prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(abstract) - xla.backend_specific_translations["cpu"][prim] = partial( + interpreters.mlir["cpu"][prim] = partial( translation, platform="cpu" ) - xla.backend_specific_translations["gpu"][prim] = partial( + interpreters.mlir["gpu"][prim] = partial( translation, platform="gpu" ) From e69ebb8ef0c536734c40f936818f6b85fa18bcfd Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 12 Aug 2024 18:20:07 +0100 Subject: [PATCH 2/4] run isort --- envpool/python/xla_template.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index ae934a12..6aad7e48 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -18,12 +18,11 @@ from typing import Any, Callable, List, Tuple, Union import numpy as np -from jax import core, dtypes +from jax import core, dtypes, interpreters from jax import numpy as jnp from jax.core import ShapedArray from jax.interpreters import xla from jax.lib import xla_client -from jax import interpreters def _shape_with_layout( From 18e6ac2247adb247ed4d71408fdd41507f078a58 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 12 Aug 2024 18:23:33 +0100 Subject: [PATCH 3/4] run yapf --- envpool/python/xla_template.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index 6aad7e48..5894d2c6 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -91,12 +91,8 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: prim.multiple_results = (len(out_specs) > 1) prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(abstract) - interpreters.mlir["cpu"][prim] = partial( - translation, platform="cpu" - ) - interpreters.mlir["gpu"][prim] = partial( - translation, platform="gpu" - ) + interpreters.mlir["cpu"][prim] = partial(translation, platform="cpu") + interpreters.mlir["gpu"][prim] = partial(translation, platform="gpu") def call(*args: Any) -> Any: return prim.bind(*args) From bc5b0c4087137073599cc89d843845e5b46da832 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Mon, 12 Aug 2024 19:27:52 +0100 Subject: [PATCH 4/4] Fix implementation --- envpool/python/xla_template.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index 5894d2c6..2fd802b8 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -18,10 +18,10 @@ from typing import Any, Callable, List, Tuple, Union import numpy as np -from jax import core, dtypes, interpreters +from jax import core, dtypes from jax import numpy as jnp from jax.core import ShapedArray -from jax.interpreters import xla +from jax.interpreters import mlir, xla from jax.lib import xla_client @@ -91,8 +91,7 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: prim.multiple_results = (len(out_specs) > 1) prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(abstract) - interpreters.mlir["cpu"][prim] = partial(translation, platform="cpu") - interpreters.mlir["gpu"][prim] = partial(translation, platform="gpu") + mlir.register_lowering(prim, translation) def call(*args: Any) -> Any: return prim.bind(*args)