Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kaixih committed Jul 5, 2024
1 parent bdc2db4 commit 9c881fc
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,14 +778,7 @@ def _get_causal_mask(T, S, dtype):
mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype))
return mask[jnp.newaxis, jnp.newaxis, :, :]

def _dot_product_attention_xla(
query: Array,
key: Array,
value: Array,
bias: Array | None,
mask: Array | None,
is_causal: bool,
scale: float):
def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
preferred_element_type=logits_dtype)
Expand Down Expand Up @@ -876,7 +869,7 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
for i in range(t.ndim):
if shape[i] != -1 and t.shape[i] != shape[i]:
raise ValueError(f"{name} shape should be {shape}: but got {t.shape}")

query = jnp.asarray(query)
key = jnp.asarray(key)
value = jnp.asarray(value)
Expand Down Expand Up @@ -931,4 +924,3 @@ def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None:
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")

0 comments on commit 9c881fc

Please sign in to comment.