Skip to content

Commit

Permalink
Break out aot_test from array_test (for serialization and other aot A…
Browse files Browse the repository at this point in the history
…PIs).

PiperOrigin-RevId: 521568985
  • Loading branch information
pschuh authored and jax authors committed Apr 3, 2023
1 parent 78678ee commit c2b15a1
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 36 deletions.
9 changes: 9 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ jax_test(
],
)

jax_test(
name = "aot_test",
srcs = ["aot_test.py"],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
)

jax_test(
name = "image_test",
srcs = ["image_test.py"],
Expand Down
102 changes: 102 additions & 0 deletions tests/aot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""

import contextlib
import os
import unittest
from absl.testing import absltest
import numpy as np

import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
from jax.sharding import PartitionSpec as P

from jax.config import config
config.parse_flags_with_absl()

prev_xla_flags = None

with contextlib.suppress(ImportError):
import pytest
pytestmark = pytest.mark.multiaccelerator


# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xb.get_backend.cache_clear()

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xb.get_backend.cache_clear()


class JaxAotTest(jtu.JaxTestCase):

def check_for_compile_options(self):
example_exe = jax.jit(lambda x: x * x).lower(
core.ShapedArray(
(2, 2), dtype=np.float32)).compile()._executable.xla_executable

# Skip if CompileOptions is not available. This is true on
# CPU/GPU/Cloud TPU for now.
try:
example_exe.compile_options()
except Exception as e:
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
raise unittest.SkipTest('Serialization not supported')
raise e

def test_pickle_pjit_lower(self):
self.check_for_compile_options()

def fun(x):
return x * x

with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
lowered = pjit(
fun, in_shardings=P('data'), out_shardings=P(None, 'data')
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))

def verify_serialization(lowered):
serialized, in_tree, out_tree = serialize(lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
self.assertEqual(compiled.as_text(), lowered.compile().as_text())

verify_serialization(lowered)
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
verify_serialization(
jax.pmap(lambda x: x * x).lower(
np.zeros((len(jax.devices()), 4), dtype=np.float32)))


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
37 changes: 1 addition & 36 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import contextlib
import math
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
Expand All @@ -31,8 +30,6 @@
from jax._src.util import safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
from jax._src import array
Expand Down Expand Up @@ -104,7 +101,7 @@ def test_jax_array_value(self, mesh_axes):

@parameterized.named_parameters(
("mesh_x_y", P("x", "y"),
# There are more slices but for convienient purposes, checking for only
# There are more slices but for convenient purposes, checking for only
# 2. The indices + shard_shape + replica_id should be unique enough.
((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
(2, 1),
Expand Down Expand Up @@ -1018,38 +1015,6 @@ def f(x):
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref1)

def test_pickle_pjit_lower(self):
example_exe = jax.jit(lambda x: x * x).lower(
core.ShapedArray(
(2, 2), dtype=np.float32)).compile()._executable.xla_executable

# Skip if CompileOptions is not available. This is true on
# CPU/GPU/Cloud TPU for now.
try:
example_exe.compile_options()
except Exception as e:
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
raise unittest.SkipTest('Serialization not supported')
raise e

def fun(x):
return x * x

with jax.sharding.Mesh(np.array(jax.devices()), ('data',)):
lowered = pjit(
fun, in_shardings=P('data'), out_shardings=P(None, 'data')
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))

def verify_serialization(lowered):
serialized, in_tree, out_tree = serialize(lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
self.assertEqual(compiled.as_text(), lowered.compile().as_text())

verify_serialization(lowered)
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
verify_serialization(
jax.pmap(lambda x: x * x).lower(
np.zeros((len(jax.devices()), 4), dtype=np.float32)))

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit c2b15a1

Please sign in to comment.