From 5b5ed76dfda68e1fbc2d53499dc7c2ae8091bc15 Mon Sep 17 00:00:00 2001
From: The praxis Authors
Date: Wed, 24 Jul 2024 02:42:09 -0700
Subject: [PATCH] [DotProductAttention] When relative bias is used, set query_segment_pos and key_segment_pos to default positions (jnp.arange) if not provided.

PiperOrigin-RevId: 655490645
---
 praxis/layers/attentions.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py
index b81a8232..ed9bc9fd 100644
--- a/praxis/layers/attentions.py
+++ b/praxis/layers/attentions.py
@@ -1731,14 +1731,20 @@ def __call__(
     # Apply relative bias.
     # Paper: https://aclanthology.org/N18-2074.pdf.
     if self.relative_bias_tpl:
-      # Create dummy variables for segment_pos if they are none so that
-      # the relative bias layer can infer the shape of the keys and queries.
+      # The relative bias expects the segment positions to be set.
+      # -> Create default segment positions if they are not provided.
       if query_segment_pos is None:
         # shape should be B x T
-        query_segment_pos = jnp.zeros(query_vec.shape[:2], dtype=jnp.int32)
+        query_segment_pos = jnp.repeat(
+            jnp.arange(query_vec.shape[1])[jnp.newaxis],
+            query_vec.shape[0],
+            axis=0,
+        )
       if key_segment_pos is None:
         # shape should be B x S
-        key_segment_pos = jnp.zeros(key_vec.shape[:2], dtype=jnp.int32)
+        key_segment_pos = jnp.repeat(
+            jnp.arange(key_vec.shape[1])[jnp.newaxis], key_vec.shape[0], axis=0
+        )
       relative_bias = self.relative_bias(query_segment_pos, key_segment_pos)
     else:
       relative_bias = None