From b7105ccd195241db8735c3fb533b5516704342e2 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 19 Jul 2024 20:22:21 +0300 Subject: [PATCH] [pallas] Fix the handling of captured consts There was an attempt to handle consts captured by the kernel, but it was incomplete and had errors: the calling convention was wrong, and the support for consts in combination with scalar prefetch and scratch values was incomplete. I expanded the tests: one in pallas_test.py and two in tpu_pallas_test.py (covering scalar prefetch, with and without scratch inputs). The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`. This differs from the previous order (`*consts, *scalar_refs, *ins, ...`) so that the block arguments (consts, ins, outs) stay together, which makes the lowering easier to write. I will follow up with a cleanup PR for the handling of grid_mapping. Here I attempted to minimize the changes. --- jax/_src/pallas/core.py | 6 +- jax/_src/pallas/mosaic/lowering.py | 10 +-- .../pallas/mosaic/pallas_call_registration.py | 1 + .../mosaic_gpu/pallas_call_registration.py | 6 +- jax/_src/pallas/pallas_call.py | 61 +++++++++++++++---- jax/_src/pallas/triton/lowering.py | 8 ++- .../pallas/triton/pallas_call_registration.py | 5 ++ jax/_src/state/utils.py | 13 +++- tests/pallas/pallas_test.py | 29 ++++++--- tests/pallas/pallas_vmap_test.py | 24 +++++++- tests/pallas/tpu_pallas_test.py | 52 +++++++++++++++- 11 files changed, 183 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index c088fbcfe1d3..ae5f267d2fbb 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -245,7 +245,7 @@ class BlockMapping: index_map_jaxpr: jax_core.ClosedJaxpr indexing_mode: IndexingMode - def compute_start_indices(self, loop_idx, *args): + def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts ) @@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping( f"{len(aval.shape)} values to match {block_shape=}. " f"Currently returning {len(out_avals)} values." 
) + if consts: + raise NotImplementedError( + f"Index map for {what}{tree_util.keystr(path)} captures constants: " + f"{consts}") return BlockMapping( block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode ) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8177c7f81a39..f4aa9a70f351 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -274,7 +274,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pl_core.GridMapping, self.mapped_dims = grid_mapping.mapped_dims num_scalar_prefetch = grid_mapping.num_index_operands num_scratch = grid_mapping.num_scratch_operands - # jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch] + # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch] num_operands = ( len(self.jaxpr.invars) - num_scalar_prefetch @@ -411,9 +411,13 @@ def lower_jaxpr_to_module( grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands] else: invars = invars[grid_mapping.num_index_operands:] + # invars now = *consts, *ins, *outs avals = tuple(v.aval for v in invars) block_operand_shapes = ( - *in_shapes[grid_mapping.num_index_operands :], + *[jax.ShapeDtypeStruct(v.aval.shape, + v.aval.dtype) + for v in invars[:grid_mapping.num_constant_operands]], + *in_shapes[grid_mapping.num_index_operands:], *out_shapes, ) assert len(block_operand_shapes) == len(grid_mapping.block_mappings) @@ -425,8 +429,6 @@ def lower_jaxpr_to_module( raise NotImplementedError( "BlockSpecs are required on TPU when grid is specified" ) - if bm.index_map_jaxpr.consts: - raise NotImplementedError("Index map jaxpr with consts not supported.") # ANY operands don't support windowing and require empty window_params. if aval.memory_space == tpu_core.TPUMemorySpace.ANY: # We may not require windowing if our block_shape matches the original diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 70d9cc469008..9e2486eb1581 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule( compiler_params: dict[str, Any]): """Lowers a pallas_call to a Mosaic TPU custom call.""" if interpret: + # TODO(necula): is this branch still needed? return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes, in_shapes=in_shapes, diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 740f0c31ebb7..6aca2b74ef04 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -42,6 +42,7 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): if interpret: + # TODO(necula): is this still needed? 
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *args, @@ -68,7 +69,10 @@ def pallas_call_lowering( if debug: print(jaxpr) print(grid_mapping) - + if grid_mapping.num_constant_operands: + raise NotImplementedError( + "captured consts not supported in the Mosaic GPU backend" + ) lowering_result = lowering.lower_jaxpr_to_module( grid_mapping, in_shapes, diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 359a42e8647b..d185133c100d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -187,12 +187,13 @@ def _pallas_call_impl_interpret( ) assert next(dynamic_grid_args_iter, None) is None with grid_mapping.trace_env(): - discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ()) + discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: print(discharged_jaxpr) out = _initialize_output_vals(out_shapes, args, input_output_aliases) scalars, args = split_list(args, [grid_mapping.num_index_operands]) # type: ignore - # invars: [*scalar_prefetch, *inputs, *outputs, *scratch] + # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch] + # args now contains: *consts, *inputs, *outputs num_invars = len(jaxpr.invars) num_inputs_outputs = ( num_invars @@ -243,6 +244,10 @@ def _pallas_call_impl_interpret( else: # Base case is always one iteration when grid is () num_iterations = 1 + + # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch) + # i:int32 is the iteration index + # loop_idx: tuple[int32] are the program ids for each grid axis def cond(carry): i, *_ = carry return i < num_iterations @@ -256,7 +261,7 @@ def body(carry): carry, scratch = split_list(carry, [num_inout]) with pallas_core.grid_env(local_grid_env): start_indices = [ - None if bm is None else bm.compute_start_indices(loop_idx, *scalars) + None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars) for bm in grid_mapping.block_mappings] blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry, is_indexing_dim) @@ -269,13 +274,14 @@ def body(carry): len(blocks), len(scratch_values), ) - blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars, + blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch) blocks = blocks[grid_mapping.num_index_operands:] blocks, out_scratch = split_list(blocks, [num_inout]) carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes, carry, blocks, is_indexing_dim) return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch) + (_, _, *carry) = lax.while_loop( cond, body, (jnp.int32(0), grid_start_indices, *carry) ) @@ -604,13 +610,13 @@ def _maybe_squeeze_out_bdim( # Ordinarily, adding support for scalar prefetch in vmap would involve # modifying the block specs in a nontrivial way. However, if we are only # vmapping over 1-sized dimensions, we can just get rid of the dimensions - # and pretend we were never vmapping over them at all. + # and pretend we were never vmapped over them at all. 
if all( bdim is batching.not_mapped or arg.shape[bdim] == 1 for arg, bdim in zip(scalar_args, scalar_bdims) ): scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims) - scalar_bdims = [None] * len(scalar_args) + scalar_bdims = [batching.not_mapped] * len(scalar_args) args = (*scalar_args, *args) dims = (*scalar_bdims, *bdims) else: @@ -648,6 +654,7 @@ def _maybe_squeeze_out_bdim( all_dims = list(dims) + [0] * len(out_shapes) num_index_operands = grid_mapping.num_index_operands + num_constant_operands = grid_mapping.num_constant_operands num_scratch_operands = grid_mapping.num_scratch_operands # Only add a batch dimension for the avals that actually have a grid mapping. @@ -661,11 +668,16 @@ def _maybe_squeeze_out_bdim( block_mappings, ) + # TODO(necula): should fix in_shapes to include the consts + dims_no_consts = ( + dims[:num_index_operands] + + dims[num_index_operands + num_constant_operands:] + ) batched_in_shapes = tuple( jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else tuple_insert(x.shape, dim, axis_size), x.dtype) - for x, dim in zip(in_shapes, dims)) + for x, dim in zip(in_shapes, dims_no_consts)) batched_out_shapes = tuple( jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype) for x in out_shapes) @@ -900,11 +912,37 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals, debug) if consts: - jaxpr = state_utils.hoist_consts_to_refs(jaxpr) # Pad ``block_mappings`` to account for the hoisted constants. + # The constants will be right after the index operands and just before + # the real inputs and outputs. + jaxpr = state_utils.hoist_consts_to_refs( + jaxpr, + index=grid_mapping.num_index_operands, + make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None)) + num_constant_operands = len(consts) + # TODO(necula): refactor grid_mapping to remove this code duplication + grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid) + if grid_mapping.num_index_operands: + grid_avals += flat_in_avals[:grid_mapping.num_index_operands] # type: ignore + # Create args, kwargs pytree def + grid_tree = tree_util.tree_structure((tuple(grid_avals), {})) + const_block_mappings = [] + for c_idx, c in enumerate(consts): + const_block_mapping = pallas_core._convert_block_spec_to_block_mapping( + grid_avals, + pallas_core.BlockSpec(None, None), + path=(tree_util.SequenceKey(c_idx),), + aval=jax_core.ShapedArray(c.shape, c.dtype), + in_tree=grid_tree, + grid=grid_mapping.grid, + mapped_dims=(), + what="consts", + ) + const_block_mappings.append(const_block_mapping) + grid_mapping = grid_mapping.replace( - block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)), - num_constant_operands=len(consts), + block_mappings=(*const_block_mappings, *grid_mapping.block_mappings), + num_constant_operands=num_constant_operands, ) return grid_mapping, jaxpr, consts, out_tree_thunk() @@ -1105,8 +1143,9 @@ def wrapped(*args): f"and to output{tree_util.keystr(out_paths[o_idx])} with " f"a different abstract value {out_aval}.") + index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) out_flat = pallas_call_p.bind( - *dynamic_grid_bounds, *consts, *flat_args, + *dynamic_grid_bounds, *index_args, *consts, *rest_args, jaxpr=jaxpr, name=name, in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in flat_args), diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 
d8eaa8d28348..5653dcf6a6b7 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context: def lower_jaxpr_to_triton_module( jaxpr: jax_core.Jaxpr, - in_shapes, + in_out_shapes, grid_mapping: GridMapping, name: str, platform: str @@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, ) + consts_shapes = [ + jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype) + for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands] + ] block_infos = [ BlockInfo( jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype), @@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module( if block_mapping is not None else None for shape_dtype, block_mapping, start_idx in zip( - (*in_shapes, *[()] * grid_mapping.num_constant_operands), + (*consts_shapes, *in_out_shapes), grid_mapping.block_mappings, start_indices, ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index e6d521692ec2..800b328f28e9 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -54,6 +54,7 @@ def pallas_call_lowering( compiler_params: dict[str, Any], ): if interpret: + # TODO(necula): is this branch still needed? return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( ctx, *in_nodes, @@ -72,6 +73,10 @@ def pallas_call_lowering( raise NotImplementedError( "dynamic grid bounds not supported in the Triton backend" ) + if grid_mapping.num_index_operands: + raise NotImplementedError( + "scalar prefetch not implemented in the Triton backend" + ) triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) [lowering_platform] = ctx.platforms or ctx.module_context.platforms diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index cb947a5ed90f..33fced775fad 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -13,6 +13,8 @@ # limitations under the License. """Utilities for tracing stateful functions.""" +from typing import Callable + from jax._src.interpreters import partial_eval as pe from jax._src import core from jax._src import linear_util as lu @@ -24,13 +26,20 @@ zip, unsafe_zip = safe_zip, zip -def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr: +def hoist_consts_to_refs( + jaxpr: core.Jaxpr, + *, + index: int = 0, + make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval) +) -> core.Jaxpr: """Hoists the constants in the given jaxpr into invars. Args: jaxpr: The jaxpr. index: The index where the invars for the constants should be inserted. By default, the new invars are inserted *before* any existing invars. + make_abstract_ref: a callable to construct an AbstractRef, or subtype + thereof, from a constant AbstractValue. Returns: A new jaxpr where the constants were hoisted into invars as ``Ref``s. 
@@ -42,7 +51,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr: isinstance(var.aval, AbstractRef) for var in jaxpr.constvars ] const_avals = [ - var.aval if is_ref else AbstractRef(var.aval) + var.aval if is_ref else make_abstract_ref(var.aval) for is_ref, var in zip(is_const_ref, jaxpr.constvars) ] in_avals = [var.aval for var in jaxpr.invars] diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index ac40d1358f85..10f3c9ee7b00 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -150,7 +150,6 @@ def pallas_call(self, *args, **kwargs): class PallasCallTest(PallasBaseTest): - def test_add_one(self): if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: # TODO: assertion failures on CPU in 64-bit mode @@ -468,19 +467,22 @@ def kernel(o_ref): def test_hoisted_consts(self): # See https://github.com/google/jax/issues/21557. - if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - self.skipTest("On TPU the test works only in interpret mode") - x = jnp.zeros(32) - indices = jnp.arange(4).reshape((2, 2)) + # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. + to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) + x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype), + grid=(2,), + in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))], + out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)), ) def kernel(src, dst): - dst[indices] = src[indices] + dst[0:1] = to_store - jax.block_until_ready(kernel(x)) + res = kernel(x) + self.assertAllClose(res[0:1], to_store) def test_vector_slicing(self): if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: @@ -744,6 +746,17 @@ def test_pallas_call_index_map_wrong_number_of_results(self): "Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."): f(a) + def test_pallas_call_index_map_captures_consts(self): + a = np.arange(256, dtype=np.int32) + index_map_result = np.array([0], dtype=np.int32) + f = self.pallas_call(lambda x_ref, o1_ref: None, + out_shape=a, + in_specs=[pl.BlockSpec((4,), lambda: index_map_result)]) + with self.assertRaisesRegex( + NotImplementedError, + "Index map for input\\[0\\] captures constants"): + f(a) + def test_pallas_call_out_specs_mismatch_shape(self): a = np.arange(256, dtype=np.int32) f = self.pallas_call(lambda x_ref, o1_ref: None, diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index 303a71ccc367..cfeff5cf79d3 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -37,7 +37,7 @@ @jtu.with_config(jax_traceback_filtering="off") -class PallasTest(jtu.JaxTestCase): +class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): @@ -58,7 +58,7 @@ def pallas_call(self, *args, **kwargs): return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) -class PallasCallVmapTest(PallasTest): +class PallasCallVmapTest(PallasBaseTest): def setUp(self): super().setUp() @@ -130,6 +130,26 @@ def add_one(x_ref, o_ref): out_ref = jnp.arange(1, 9).reshape((4, 2)) np.testing.assert_allclose(out, out_ref) + def test_vmap_with_hoisted_consts(self): + # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. 
+ to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) + x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128)) + + @jax.vmap + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype), + grid=(2,), + in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))], + out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)), + ) + def kernel(src, dst): + dst[0:1] = to_store + + res = kernel(x) + for i in range(x.shape[0]): + self.assertAllClose(res[i, 0:1], to_store) + def test_vmap_of_kernel_with_input_output_aliases(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32), diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index b428705427e8..94162b343eff 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -17,6 +17,7 @@ import contextlib import functools import io +import math import re import sys from absl.testing import absltest @@ -71,7 +72,6 @@ def pallas_call(self, *args, **kwargs): class PallasCallScalarPrefetchTest(PallasBaseTest): - def test_trivial_scalar_prefetch(self): def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] @@ -115,6 +115,56 @@ def body(_, x_ref, o_ref): )(s, x) np.testing.assert_array_equal(out, x) + @jtu.parameterized_filterable( + kwargs=[ + dict(scratch=scratch, vmap=vmap) + for scratch in [True, False] + for vmap in [True, False] + ] + ) + def test_scalar_prefetch_hoisted_const(self, *, scratch: bool, vmap: bool): + if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled: + self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64") + # to_store will be hoisted as constants. Choose distinct shapes from in/outs. + to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) + if vmap: + x_shape = (4, 16, 128) + else: + x_shape = (16, 128) + x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape) + + def f(x): + s = jnp.array([1, 0], jnp.int32) # iteration 0 -> 1, iteration 1 -> 0 + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(2,), + in_specs=[pl.BlockSpec((8, 128), + lambda i, s_ref: (pl.load(s_ref, (i,)), 0))], + out_specs=pl.BlockSpec((32, 128), + lambda i, s_ref: (pl.load(s_ref, i), 0)), + scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch + else []), + ), + ) + def kernel(s_ref, src, dst, *scratch_refs): + store_idx = s_ref[pl.program_id(0)] + pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store) + return kernel(s, x) + + if vmap: + f = jax.vmap(f) + res = f(x) + if vmap: + for i in range(x.shape[0]): + self.assertAllClose(res[i, 0:1], to_store) + self.assertAllClose(res[i, 33:34], to_store) + else: + self.assertAllClose(res[0:1], to_store) + self.assertAllClose(res[33:34], to_store) + def test_block_spec_with_wrong_block_shape_errors(self): def body(x_ref, o_ref): o_ref[...] = x_ref[...]