Skip to content

Commit

Permalink
Fix Pallas tests that broke after recent changes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 563750775
  • Loading branch information
apaszke authored and jax authors committed Sep 8, 2023
1 parent 601d67a commit 592eb44
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def add_one(x_ref, o_ref):
def test_add_vector_block_spec(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
in_specs=(pl.BlockSpec(lambda i: i, (1,)),),
out_specs=(pl.BlockSpec(lambda i: i, (1,)),),
in_specs=[pl.BlockSpec(lambda i: i, (1,))],
out_specs=pl.BlockSpec(lambda i: i, (1,)),
grid=8, debug=False)
def add_one(x_ref, o_ref):
o_ref[0] = x_ref[0] + 1
Expand All @@ -174,8 +174,8 @@ def add_one(x_ref, o_ref):
def test_add_matrix_block_spec(self):
@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 8), jnp.int32),
in_specs=(pl.BlockSpec(lambda i, j: (i, j), (2, 2)),),
out_specs=(pl.BlockSpec(lambda i, j: (i, j), (2, 2)),),
in_specs=[pl.BlockSpec(lambda i, j: (i, j), (2, 2))],
out_specs=pl.BlockSpec(lambda i, j: (i, j), (2, 2)),
grid=(4, 4))
def add_one(x_ref, o_ref):
o_ref[:, :] = x_ref[:, :] + 1
Expand Down Expand Up @@ -1203,7 +1203,6 @@ def test_slicing_block_spec(self):
pl.BlockSpec(lambda _: (0, 0), (None, 4)),
pl.BlockSpec(lambda _: (1, 0), (None, 4)),
],
out_specs=None,
debug=False, grid=1)
def add_vectors(x_ref, y_ref, o_ref):
o_ref[:] = x_ref[:] + y_ref[:]
Expand Down

0 comments on commit 592eb44

Please sign in to comment.