Skip to content

Commit

Permalink
Update to jax 0.4.19 (#11)
Browse files Browse the repository at this point in the history
* Update to jax 0.4.19

* Update README; Python to 3.10
  • Loading branch information
jsbrittain authored Nov 2, 2023
1 parent 1cb4c39 commit 388c8ba
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/action/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
fetch-depth: 0

- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: '3.10'

- name: Install dependencies
run: |
Expand Down
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ everything that I'll talk about is covered in more detail somewhere else (even
if that somewhere is just a comment in some source code), but hopefully this
summary can point you in the right direction if you have a use case like this.

**A warning**: I'm writing this in January 2021 and much of what I'm talking
about is based on essentially undocumented APIs that are likely to change.
**A warning**: I'm writing this in January 2021 (most recent update November 2023; see
github for the full revision history) and much of what I'm talking about is based on
essentially undocumented APIs that are likely to change.
Furthermore, I'm not affiliated with the JAX project and I'm far from an expert
so I'm sure there are wrong things that I say. I'll try to update this if I
notice things changing or if I learn of issues, but no promises! So, MIT license
Expand Down Expand Up @@ -358,7 +359,7 @@ from jax.lib import xla_client
from kepler_jax import cpu_ops

for _name, _value in cpu_ops.registrations().items():
xla_client.register_cpu_custom_call_target(_name, _value)
xla_client.register_custom_call_target(_name, _value, platform="cpu")
```

Then, the **lowering rule** is defined roughly as follows (the one you'll
Expand Down Expand Up @@ -400,13 +401,13 @@ def _kepler_lowering(ctx, mean_anom, ecc):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)
).results

mlir.register_lowering(
_kepler_prim,
Expand Down Expand Up @@ -651,15 +652,15 @@ def _kepler_lowering_gpu(ctx, mean_anom, ecc):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU-specific additional data for the kernel
backend_config=opaque
)
).results

mlir.register_lowering(
_kepler_prim,
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def build_extension(self, ext):
packages=find_packages("src"),
package_dir={"": "src"},
include_package_data=True,
install_requires=["jax", "jaxlib"],
install_requires=[
"jax>=0.4.16",
"jaxlib>=0.4.16"
],
extras_require={"test": "pytest"},
ext_modules=extensions,
cmdclass={"build_ext": CMakeBuildExt},
Expand Down
12 changes: 6 additions & 6 deletions src/kepler_jax/kepler_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from jax import core, dtypes, lax
from jax import numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir, xla
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
Expand All @@ -16,7 +16,7 @@
from . import cpu_ops

for _name, _value in cpu_ops.registrations().items():
xla_client.register_cpu_custom_call_target(_name, _value)
xla_client.register_custom_call_target(_name, _value, platform="cpu")

# If the GPU version exists, also register those
try:
Expand Down Expand Up @@ -93,13 +93,13 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mlir.ir_constant(size), mean_anom, ecc],
# Layout specification:
operand_layouts=[(), layout, layout],
result_layouts=[layout, layout]
)
).results

elif platform == "gpu":
if gpu_ops is None:
Expand All @@ -113,15 +113,15 @@ def _kepler_lowering(ctx, mean_anom, ecc, *, platform="cpu"):
return custom_call(
op_name,
# Output types
out_types=[dtype, dtype],
result_types=[dtype, dtype],
# The inputs:
operands=[mean_anom, ecc],
# Layout specification:
operand_layouts=[layout, layout],
result_layouts=[layout, layout],
# GPU specific additional data
backend_config=opaque
)
).results

raise ValueError(
"Unsupported platform; this must be either 'cpu' or 'gpu'"
Expand Down

0 comments on commit 388c8ba

Please sign in to comment.