Commit 777d628

source_padding tensor has to have the same batch size as `source_vecs` / `source_contexts` inputs to attention (all versions). Add the necessary size checks to `PackSource` functions.

Fix test sending invalid sizes.

PiperOrigin-RevId: 681913876
lingvo-bot authored and copybara-github committed Oct 3, 2024
1 parent 2cb07a2 commit 777d628
Showing 2 changed files with 48 additions and 13 deletions.
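The shape contract the new checks enforce, as a minimal plain-TensorFlow sketch: the source-side tensors must share their leading time and batch dimensions. Tensor names mirror the `PackSource` arguments; the sizes are illustrative only, not taken from the change.

```python
import tensorflow as tf

# Illustrative sizes; in practice these come from the encoder outputs.
time, batch, source_dim, context_dim = 5, 3, 8, 16

source_vecs = tf.zeros([time, batch, source_dim])       # [time, batch, source_dim]
source_contexts = tf.zeros([time, batch, context_dim])  # [time, batch, context_dim]
source_padding = tf.zeros([time, batch])                # must share time/batch with source_vecs

# The added checks amount to asserting that the leading dimensions agree.
assert source_contexts.shape.as_list()[:2] == source_vecs.shape.as_list()[:2]
assert source_padding.shape.as_list() == source_vecs.shape.as_list()[:2]
```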
52 changes: 45 additions & 7 deletions lingvo/core/attention.py
@@ -807,6 +807,14 @@ def PackSource(
     Returns:
       A NestedMap containing the packed source.
     """
+    time, batch_size = py_utils.GetShape(source_vecs, 2)
+    source_contexts = py_utils.HasShape(source_contexts, [time, batch_size, -1])
+    source_padding = py_utils.HasShape(source_padding, [time, batch_size])
+    if source_segment_id is not None:
+      source_segment_id = py_utils.HasShape(
+          source_segment_id, [time, batch_size]
+      )
+
     with tf.name_scope(self.params.name):
       if source_segment_id is None:
         source_segment_id = tf.zeros_like(source_padding)
@@ -1197,6 +1205,14 @@ def PackSource(
       [batch_size, time, some_dim] and `source_padding` is a tensor of shape
       [time, batch_size].
     """
+    time, batch_size = py_utils.GetShape(source_vecs, 2)
+    source_contexts = py_utils.HasShape(source_contexts, [time, batch_size, -1])
+    source_padding = py_utils.HasShape(source_padding, [time, batch_size])
+    if source_segment_id is not None:
+      source_segment_id = py_utils.HasShape(
+          source_segment_id, [time, batch_size]
+      )
+
     concated_source_vecs = tf.identity(source_vecs)
     concated_source_contexts = tf.transpose(source_contexts, [1, 0, 2])
     if source_segment_id is None:
@@ -1627,13 +1643,11 @@ def PackSource(
     # Check input tensor shapes
     # [time_steps, batch_size, source_dim]
     source_vecs = py_utils.HasRank(source_vecs, 3)
-    [time_steps, batch_size] = py_utils.GetShape(source_vecs, 2)
-    if p.use_source_vec_as_attention_value:
-      assert source_contexts is not None
-      # [time_steps, batch_size, context_dim]
-      source_contexts = py_utils.HasShape(
-          source_contexts, [time_steps, batch_size, -1]
-      )
+    time_steps, batch_size = py_utils.GetShape(source_vecs, 2)
+    # [time_steps, batch_size, context_dim]
+    source_contexts = py_utils.HasShape(
+        source_contexts, [time_steps, batch_size, -1]
+    )
     source_padding = py_utils.HasShape(source_padding, [time_steps, batch_size])
     if source_segment_id is not None:
       source_segment_id = py_utils.HasShape(
@@ -2643,6 +2657,14 @@ def PackSource(
       source_padding: tf.Tensor,
       source_segment_id: Optional[tf.Tensor] = None,
   ) -> py_utils.NestedMap:
+    time, batch_size = py_utils.GetShape(source_vecs, 2)
+    source_contexts = py_utils.HasShape(source_contexts, [time, batch_size, -1])
+    source_padding = py_utils.HasShape(source_padding, [time, batch_size])
+    if source_segment_id is not None:
+      source_segment_id = py_utils.HasShape(
+          source_segment_id, [time, batch_size]
+      )
+
     with tf.name_scope(self.params.name):
       if source_segment_id is None:
         source_segment_id = tf.zeros_like(source_padding)
@@ -2959,6 +2981,14 @@ def PackSource(
       source_padding: tf.Tensor,
       source_segment_id: Optional[tf.Tensor] = None,
   ) -> py_utils.NestedMap:
+    time, batch_size = py_utils.GetShape(source_vecs, 2)
+    source_contexts = py_utils.HasShape(source_contexts, [time, batch_size, -1])
+    source_padding = py_utils.HasShape(source_padding, [time, batch_size])
+    if source_segment_id is not None:
+      source_segment_id = py_utils.HasShape(
+          source_segment_id, [time, batch_size]
+      )
+
     with tf.name_scope(self.params.name):
       if source_segment_id is None:
         source_segment_id = tf.zeros_like(source_padding)
@@ -3361,6 +3391,14 @@ def PackSource(
       source_padding: tf.Tensor,
       source_segment_id: Optional[tf.Tensor] = None,
   ) -> py_utils.NestedMap:
+    time, batch_size = py_utils.GetShape(source_vecs, 2)
+    source_contexts = py_utils.HasShape(source_contexts, [time, batch_size, -1])
+    source_padding = py_utils.HasShape(source_padding, [time, batch_size])
+    if source_segment_id is not None:
+      source_segment_id = py_utils.HasShape(
+          source_segment_id, [time, batch_size]
+      )
+
     with tf.name_scope(self.params.name):
       if source_segment_id is None:
         source_segment_id = tf.zeros_like(source_padding)
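As a rough illustration of what the new checks catch, the sketch below reproduces the added `py_utils.GetShape` / `py_utils.HasShape` pattern on a padding tensor whose batch dimension disagrees with `source_vecs`; with the checks in place, the same call inside `PackSource` should now fail instead of propagating the mismatch downstream. Sizes are made up.

```python
from lingvo import compat as tf
from lingvo.core import py_utils

source_vecs = tf.zeros([6, 2, 4])   # [time=6, batch=2, source_dim=4]
bad_padding = tf.zeros([6, 3])      # batch=3 does not match source_vecs

time, batch_size = py_utils.GetShape(source_vecs, 2)
# This mirrors the check added to each PackSource; for the tensors above it
# should raise a shape-mismatch error rather than silently continuing.
bad_padding = py_utils.HasShape(bad_padding, [time, batch_size])
```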
9 changes: 3 additions & 6 deletions lingvo/core/steps/attention_steps_test.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for third_party.py.lingvo.core.steps.attention_steps."""

 from lingvo import compat as tf
 from lingvo.core import attention
@@ -37,8 +36,7 @@ def testAttentionStep(self):
     source_contexts = tf.constant(
         np.random.rand(src_length, src_batch_size, src_context_dim),
         dtype=tf.float32)
-    source_padding = tf.zeros([src_length, target_batch_size],
-                              dtype=tf.float32)
+    source_padding = tf.zeros([src_length, src_batch_size], dtype=tf.float32)
     query_vec = tf.constant(
         np.random.rand(target_batch_size, query_dim), dtype=tf.float32)

@@ -463,8 +461,7 @@ def testAttentionBlockStep(self):
     src_dim = context_dim
     source_vecs = tf.constant(
         np.random.rand(src_length, src_batch_size, src_dim), dtype=tf.float32)
-    source_padding = tf.zeros([src_length, target_batch_size],
-                              dtype=tf.float32)
+    source_padding = tf.zeros([src_length, src_batch_size], dtype=tf.float32)

     p = attention_steps.AttentionBlockStep.Params()
     p.attention.atten.params_init = py_utils.WeightInit.Gaussian(0.1, 12345)
@@ -488,7 +485,7 @@ def testAttentionBlockStep(self):
                                 state0)

     self.evaluate(tf.global_variables_initializer())
-    output, state1 = self.evaluate([output, state1])
+    output, _ = self.evaluate([output, state1])

     self.assertAllClose(
         output, {
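For clarity on the test fix: the test batches queries more densely than sources (`target_batch_size` vs. `src_batch_size`), and the padding describes the source sequences, so its batch dimension has to follow `src_batch_size`. A small sketch of the corrected shapes, with illustrative sizes:

```python
import numpy as np
from lingvo import compat as tf

src_length, src_batch_size, target_batch_size = 4, 2, 6
src_dim, query_dim = 4, 5  # illustrative feature sizes

source_vecs = tf.constant(
    np.random.rand(src_length, src_batch_size, src_dim), dtype=tf.float32)
# Padding is a property of the source sequences, so it uses src_batch_size.
source_padding = tf.zeros([src_length, src_batch_size], dtype=tf.float32)
# Queries may be batched differently; only the source-side tensors must agree.
query_vec = tf.constant(
    np.random.rand(target_batch_size, query_dim), dtype=tf.float32)
```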
