Increase test shard_count for shape_poly_test on GPU
Reverts changelist 723586237

FUTURE_COPYBARA_INTEGRATE_REVIEW=#25519 from emilyfertig:debug-nans e58f702
PiperOrigin-RevId: 723915109
gnecula authored and Google-ML-Automation committed Feb 6, 2025
1 parent 0fb278a commit 52d0b04
Showing 28 changed files with 912 additions and 128 deletions.
134 changes: 132 additions & 2 deletions docs/sharded-computation.ipynb
@@ -264,10 +264,52 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "UEObolTqw4pp"
},
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n",
"\n",
"To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aKNeOHTJnqmS",
"outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pinned_host\n",
"device\n"
]
}
],
"source": [
"s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
"s_dev = s_host.with_memory_kind('device')\n",
"arr_host = jax.device_put(arr, s_host)\n",
"arr_dev = jax.device_put(arr, s_dev)\n",
"print(arr_host.sharding.memory_kind)\n",
"print(arr_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDHYnVqHwaST"
},
"source": [
"## 1. Automatic parallelism via `jit`\n",
"\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
@@ -354,10 +396,98 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "Q4N5mrr9i_ki"
},
"source": [
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
"\n",
"### 1.1 Sharding transformation between memory types\n",
"\n",
"The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
"\n",
"#### Example 1: Pinned host to device memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PXu3MhafyRHo",
"outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"device\n"
]
}
],
"source": [
"f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
"out_dev = f(arr_host)\n",
"print(out_dev)\n",
"print(out_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LuYFqpcBySiX"
},
"source": [
"#### Example 2: Device to pinned_host memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qLsgNlKfybRw",
"outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"pinned_host\n"
]
}
],
"source": [
"g = jax.jit(lambda x: x, out_shardings=s_host)\n",
"out_host = g(arr_dev)\n",
"print(out_host)\n",
"print(out_host.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7BGD31-owaSU"
},
"source": [
"## 2. Semi-automated sharding with constraints\n",
"\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
67 changes: 67 additions & 0 deletions docs/sharded-computation.md
@@ -90,8 +90,31 @@ print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```

+++ {"id": "UEObolTqw4pp"}

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.

The {class}`~jax.sharding.NamedSharding` includes a `memory_kind` parameter, which determines the kind of memory the array is placed in and defaults to `device`. Set it to `pinned_host` if you prefer to place the array in pinned host memory.

To create a new sharding that differs from an existing one only in its memory kind, use the `with_memory_kind` method of the existing sharding.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```
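
As a minimal illustrative sketch (assuming the `mesh`, `P`, and `s_dev` objects defined above), the `device`-memory sharding can also be constructed directly rather than derived with `with_memory_kind`:

```{code-cell}
# Illustrative sketch: construct the device-memory sharding directly.
s_dev_direct = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='device')
print(s_dev_direct.memory_kind)
# Expected to compare equal to s_dev, since mesh, spec, and memory kind match.
print(s_dev_direct == s_dev)
```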

+++ {"id": "jDHYnVqHwaST"}

## 1. Automatic parallelism via `jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
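
For instance, here is a minimal sketch (assuming the `arr_sharded` array and the `jnp` import from earlier cells) of an elementwise function that runs in parallel across the mesh, with no sharding annotations inside the function body:

```{code-cell}
# Illustrative sketch: no sharding annotations are needed inside `layer`;
# the compiler propagates the input sharding through the computation.
@jax.jit
def layer(x):
  return 2 * jnp.sin(x) + 1

out = layer(arr_sharded)
jax.debug.visualize_array_sharding(out)
```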
@@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result)
print(result)
```

+++ {"id": "Q4N5mrr9i_ki"}

The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second two on `1` and `7`, and so on.

### 1.1 Sharding transformation between memory types

The output sharding of a {func}`jax.jit`-compiled function can differ from the input sharding if you specify it with the `out_shardings` parameter. In particular, the `memory_kind` of the output can differ from that of the input array.

#### Example 1: `pinned_host` to `device` memory

In the example below, the {func}`jax.jit`-compiled function `f` takes an array sharded in `pinned_host` memory and returns an array in `device` memory.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```

+++ {"id": "LuYFqpcBySiX"}

#### Example 2: `device` to `pinned_host` memory

In the example below, the {func}`jax.jit`-compiled function `g` takes an array sharded in `device` memory and returns an array in `pinned_host` memory.

```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```
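
The same mechanism composes with actual computation, so a result can be written straight to host memory. A hedged sketch, reusing `s_host` and `arr_dev` from above:

```{code-cell}
# Illustrative sketch: `out_shardings` also applies when the jitted function
# does real work, so the computed result lands in pinned host memory directly.
h = jax.jit(lambda x: 2 * x + 1, out_shardings=s_host)
offloaded = h(arr_dev)
print(offloaded.sharding.memory_kind)  # expected: pinned_host
```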

+++ {"id": "7BGD31-owaSU"}

## 2. Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put`) together with {func}`jax.jit` for more control over how the compiler constrains the distribution of intermediate values and outputs.
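
As a minimal sketch (assuming `mesh`, `jnp`, and the `(4, 8)`-shaped `arr_sharded` from earlier cells), the constraint can be applied to an intermediate value inside a jitted function:

```{code-cell}
# Illustrative sketch: constrain how an intermediate value is distributed.
@jax.jit
def f_constrained(x):
  y = x.sum(axis=0)  # shape (8,)
  # Ask the compiler to shard the intermediate along the 'x' mesh axis.
  return jax.lax.with_sharding_constraint(y, jax.NamedSharding(mesh, P('x')))

print(f_constrained(arr_sharded).sharding)
```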
1 change: 1 addition & 0 deletions jax/BUILD
@@ -499,6 +499,7 @@ pytype_strict_library(
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

50 changes: 40 additions & 10 deletions jax/_src/api.py
@@ -99,6 +99,7 @@
zip, unsafe_zip = safe_zip, zip


@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
@@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output):

try:
dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
except dispatch.InternalFloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value
print("Invalid nan value encountered in the output of a C++-jit/pmap "
"function. Calling the de-optimized version.")
fun._cache_miss(*args, **kwargs)[0] # probably won't return
if hasattr(fun, '_fun'):
f = fun._fun
if getattr(f, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
# compiled_fun can only raise in this case
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
raise AssertionError("Unreachable") from e
else:
# TODO(emilyaf): Shouldn't need this fallback.
raise

def _update_debug_special_global(_):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
@@ -1574,11 +1581,14 @@ def cache_miss(*args, **kwargs):

execute: Callable | None = None
with core.take_current_trace() as trace:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
try:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')

out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree()
@@ -1629,6 +1639,7 @@ def cache_miss(*args, **kwargs):
_pmap_cache_clears.add(cpp_mapped_f)

pmap_f = wraps(fun)(cpp_mapped_f)
pmap_f._fun = fun

@api_boundary
def lower(*args, **kwargs):
@@ -1674,6 +1685,7 @@ def trace(*args, **kwargs):
_pmap_cache_clears = weakref.WeakSet() # type: ignore


@api_boundary
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
@@ -1878,6 +1890,7 @@ def fun(*tangents):

return apply_flat_fun_nokwargs(fun, io_tree, py_args)

@api_boundary
def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
if len(py_args_) != 1:
msg = (f"The function returned by `jax.vjp` applied to {name} was called "
@@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
...
@api_boundary
def vjp(
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
@@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return None


@lru_cache(maxsize=2048)
def _check_string_compatible_sharding(s):
"""Checks if target devices are compatible with string arrays."""
if isinstance(s, xc.Device) and s.device_kind == "cpu":
return
if (isinstance(s, Sharding)
and s._internal_device_list[0].device_kind == "cpu"):
return
raise TypeError(
"String arrays can only be sharded to CPU devices. Received"
f" unsupported device or sharding: {s}")

# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input.
@lru_cache(maxsize=2048)
@@ -2235,6 +2261,10 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}")

if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
_check_string_compatible_sharding(s)

if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.get_token_aval()
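Taken together, these api.py changes route NaN/Inf detection through `dispatch.InternalFloatingPointError` and re-raise user-facing `FloatingPointError`s. A hedged sketch of the user-visible behavior this supports, using the documented `jax_debug_nans` flag (the exact error text is illustrative):

```python
# Illustrative sketch: with jax_debug_nans enabled, a NaN produced inside a
# jitted function surfaces as a FloatingPointError rather than silently
# propagating through the computation.
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):
  return jnp.log(x)  # log of a negative number produces NaN

try:
  f(jnp.float32(-1.0))
except FloatingPointError as err:
  print("caught:", err)
```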
7 changes: 5 additions & 2 deletions jax/_src/core.py
@@ -1472,11 +1472,14 @@ def lattice_join(x, y):

def valid_jaxtype(x) -> bool:
try:
abstractify(x)
aval = abstractify(x)
except TypeError:
return False
else:
return True
if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
return False
else:
return True

def check_valid_jaxtype(x):
if not valid_jaxtype(x):
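The string-dtype changes in api.py and core.py above share one intent: string arrays may only be placed on CPU devices, and they are not treated as ordinary jaxtypes. A self-contained sketch of that rule, with illustrative stand-ins rather than the library's internal helpers:

```python
# Illustrative stand-ins for the internal helpers in the diffs above;
# this is not library code.
def check_string_compatible_sharding(device_kind: str) -> None:
  # Mirrors _check_string_compatible_sharding: only CPU targets are allowed.
  if device_kind != "cpu":
    raise TypeError("String arrays can only be sharded to CPU devices.")

def valid_jaxtype(abstractifies: bool, has_string_dtype: bool) -> bool:
  # Mirrors the new valid_jaxtype: string-dtype values are rejected even
  # though they abstractify without error.
  return abstractifies and not has_string_dtype
```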