Skip to content

Commit

Permalink
[DotProductAttention] When relative bias is used, set query_segment_p…
Browse files Browse the repository at this point in the history
…os and key_segment_pos to default positions (jnp.arange) if not provided.

PiperOrigin-RevId: 655490645
  • Loading branch information
The praxis Authors committed Jul 24, 2024
1 parent e9886a3 commit 5b5ed76
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions praxis/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,14 +1731,20 @@ def __call__(
# Apply relative bias.
# Paper: https://aclanthology.org/N18-2074.pdf.
if self.relative_bias_tpl:
# Create dummy variables for segment_pos if they are none so that
# the relative bias layer can infer the shape of the keys and queries.
# The relative bias expects the segment positions to be set.
# -> Create default segment positions if they are not provided.
if query_segment_pos is None:
# shape should be B x T
query_segment_pos = jnp.zeros(query_vec.shape[:2], dtype=jnp.int32)
query_segment_pos = jnp.repeat(
jnp.arange(query_vec.shape[1])[jnp.newaxis],
query_vec.shape[0],
axis=0,
)
if key_segment_pos is None:
# shape should be B x S
key_segment_pos = jnp.zeros(key_vec.shape[:2], dtype=jnp.int32)
key_segment_pos = jnp.repeat(
jnp.arange(key_vec.shape[1])[jnp.newaxis], key_vec.shape[0], axis=0
)
relative_bias = self.relative_bias(query_segment_pos, key_segment_pos)
else:
relative_bias = None
Expand Down

0 comments on commit 5b5ed76

Please sign in to comment.