From 388c8ba5d8f2f65eb12019f2f388cea22102a241 Mon Sep 17 00:00:00 2001 From: jsbrittain <98161205+jsbrittain@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:50:57 +0000 Subject: [PATCH] Update to jax 0.4.19 (#11) * Update to jax 0.4.19 * Update README; Python to 3.10 --- .github/action/Dockerfile | 2 +- .github/workflows/tests.yml | 4 ++-- README.md | 15 ++++++++------- setup.py | 5 ++++- src/kepler_jax/kepler_jax.py | 12 ++++++------ 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/.github/action/Dockerfile b/.github/action/Dockerfile index a3de759..9279506 100644 --- a/.github/action/Dockerfile +++ b/.github/action/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:11.8.0-devel-ubuntu20.04 +FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 52be497..266a102 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,9 +21,9 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.10' - name: Install dependencies run: | diff --git a/README.md b/README.md index 157cc19..601e6ef 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,9 @@ everything that I'll talk about is covered in more detail somewhere else (even if that somewhere is just a comment in some source code), but hopefully this summary can point you in the right direction if you have a use case like this. -**A warning**: I'm writing this in January 2021 and much of what I'm talking -about is based on essentially undocumented APIs that are likely to change. +**A warning**: I'm writing this in January 2021 (most recent update November 2023; see +github for the full revision history) and much of what I'm talking about is based on +essentially undocumented APIs that are likely to change. Furthermore, I'm not affiliated with the JAX project and I'm far from an expert so I'm sure there are wrong things that I say. I'll try to update this if I notice things changing or if I learn of issues, but no promises! So, MIT license @@ -358,7 +359,7 @@ from jax.lib import xla_client from kepler_jax import cpu_ops for _name, _value in cpu_ops.registrations().items(): - xla_client.register_cpu_custom_call_target(_name, _value) + xla_client.register_custom_call_target(_name, _value, platform="cpu") ``` Then, the **lowering rule** is defined roughly as follows (the one you'll @@ -400,13 +401,13 @@ def _kepler_lowering(ctx, mean_anom, ecc): return custom_call( op_name, # Output types - out_types=[dtype, dtype], + result_types=[dtype, dtype], # The inputs: operands=[mlir.ir_constant(size), mean_anom, ecc], # Layout specification: operand_layouts=[(), layout, layout], result_layouts=[layout, layout] - ) + ).results mlir.register_lowering( _kepler_prim, @@ -651,7 +652,7 @@ def _kepler_lowering_gpu(ctx, mean_anom, ecc): return custom_call( op_name, # Output types - out_types=[dtype, dtype], + result_types=[dtype, dtype], # The inputs: operands=[mean_anom, ecc], # Layout specification: @@ -659,7 +660,7 @@ def _kepler_lowering_gpu(ctx, mean_anom, ecc): result_layouts=[layout, layout], # GPU-specific additional data for the kernel backend_config=opaque - ) + ).results mlir.register_lowering( _kepler_prim, diff --git a/setup.py b/setup.py index b7e5f9c..7e9a341 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,10 @@ def build_extension(self, ext): packages=find_packages("src"), package_dir={"": "src"}, include_package_data=True, - install_requires=["jax", "jaxlib"], + install_requires=[ + "jax>=0.4.16", + "jaxlib>=0.4.16" + ], extras_require={"test": "pytest"}, ext_modules=extensions, cmdclass={"build_ext": CMakeBuildExt}, diff --git a/src/kepler_jax/kepler_jax.py b/src/kepler_jax/kepler_jax.py index 93bfd47..9bd0ce7 100644 --- a/src/kepler_jax/kepler_jax.py +++ b/src/kepler_jax/kepler_jax.py @@ -7,7 +7,7 @@ import numpy as np from jax import core, dtypes, lax from jax import numpy as jnp -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.interpreters import ad, batching, mlir, xla from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call @@ -16,7 +16,7 @@ from . import cpu_ops for _name, _value in cpu_ops.registrations().items(): - xla_client.register_cpu_custom_call_target(_name, _value) + xla_client.register_custom_call_target(_name, _value, platform="cpu") # If the GPU version exists, also register those try: @@ -93,13 +93,13 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"): return custom_call( op_name, # Output types - out_types=[dtype, dtype], + result_types=[dtype, dtype], # The inputs: operands=[mlir.ir_constant(size), mean_anom, ecc], # Layout specification: operand_layouts=[(), layout, layout], result_layouts=[layout, layout] - ) + ).results elif platform == "gpu": if gpu_ops is None: @@ -113,7 +113,7 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"): return custom_call( op_name, # Output types - out_types=[dtype, dtype], + result_types=[dtype, dtype], # The inputs: operands=[mean_anom, ecc], # Layout specification: @@ -121,7 +121,7 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"): result_layouts=[layout, layout], # GPU specific additional data backend_config=opaque - ) + ).results raise ValueError( "Unsupported platform; this must be either 'cpu' or 'gpu'"