From 592eb44d7e288ad5fe7b1bfb01e715ef3cb0ca05 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 8 Sep 2023 07:49:40 -0700 Subject: [PATCH] Fix Pallas tests that broke after recent changes PiperOrigin-RevId: 563750775 --- tests/pallas/pallas_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9de56d9a8e27..00f486d1f39c 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -163,8 +163,8 @@ def add_one(x_ref, o_ref): def test_add_vector_block_spec(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32), - in_specs=(pl.BlockSpec(lambda i: i, (1,)),), - out_specs=(pl.BlockSpec(lambda i: i, (1,)),), + in_specs=[pl.BlockSpec(lambda i: i, (1,))], + out_specs=pl.BlockSpec(lambda i: i, (1,)), grid=8, debug=False) def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -174,8 +174,8 @@ def add_one(x_ref, o_ref): def test_add_matrix_block_spec(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32), - in_specs=(pl.BlockSpec(lambda i, j: (i, j), (2, 2)),), - out_specs=(pl.BlockSpec(lambda i, j: (i, j), (2, 2)),), + in_specs=[pl.BlockSpec(lambda i, j: (i, j), (2, 2))], + out_specs=pl.BlockSpec(lambda i, j: (i, j), (2, 2)), grid=(4, 4)) def add_one(x_ref, o_ref): o_ref[:, :] = x_ref[:, :] + 1 @@ -1203,7 +1203,6 @@ def test_slicing_block_spec(self): pl.BlockSpec(lambda _: (0, 0), (None, 4)), pl.BlockSpec(lambda _: (1, 0), (None, 4)), ], - out_specs=None, debug=False, grid=1) def add_vectors(x_ref, y_ref, o_ref): o_ref[:] = x_ref[:] + y_ref[:]