Commit c66918a
Added mocking to test tokenizer changes, directly pass padding strategy
Signed-off-by: kcirred <[email protected]>
kcirred committed Oct 15, 2024
1 parent 4f8a821 commit c66918a
Showing 2 changed files with 54 additions and 14 deletions.
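
Editor's note (not part of the commit): the "mocking" in the commit title refers to the standard unittest.mock.patch pattern with a side_effect callable, which the new test applies to sum_token_count. A minimal, self-contained sketch of that pattern follows; it patches os.cpu_count purely for illustration, since the real patch target lives in caikit_nlp and is shown in the diff below.

from unittest.mock import patch
import os

def fake_cpu_count():
    # Stand-in behaviour injected while the patch is active.
    return 42

# patch() swaps the named attribute for a Mock for the duration of the block;
# assigning side_effect routes every call through fake_cpu_count. The new test
# uses the same pattern to replace sum_token_count in the embedding module.
with patch("os.cpu_count") as mock_cpu_count:
    mock_cpu_count.side_effect = fake_cpu_count
    assert os.cpu_count() == 42
# Outside the block the original os.cpu_count is restored automatically.
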
1 change: 0 additions & 1 deletion caikit_nlp/modules/text_embedding/embedding.py
@@ -37,7 +37,6 @@
 from torch import nn
 from torch.backends import mps
 from transformers import BatchEncoding
-from transformers.tokenization_utils import PaddingStrategy
 import numpy as np
 import torch
 
67 changes: 54 additions & 13 deletions tests/modules/text_embedding/test_embedding.py
@@ -11,6 +11,8 @@
 import numpy as np
 import pytest
 import torch
+from unittest.mock import patch
+from transformers import BatchEncoding
 
 # First Party
 from caikit.core import ModuleConfig
@@ -1144,21 +1146,60 @@ def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
         separate_vectors[1], separate_vectors[2], rtol=1e-05, atol=1e-08
     )
 
-@pytest.mark.parametrize("pad_to_max_length", [None, False, True, 0, 1])
-def test_pad_to_max_length(pad_to_max_length, loaded_model):
-    """Tests for tokenization kwargs pad_to_max_length will modify tokenizer and give same result"""
+def custom_sum_token_count(
+    tokenized: BatchEncoding,
+) -> int:
+    """Returns total number of tokens regardless of attention_mask value
+    """
+
+    token_count = 0
+    for encoding in tokenized.encodings:
+        token_count += len(encoding.attention_mask)
+
+    return token_count
+
+@pytest.mark.parametrize("padding_strategy", [True, 'max_length'])
+def test_pad_to_max_length(padding_strategy, loaded_model):
+    """Tests for tokenization kwargs max_length will modify tokenizer"""
     model_max = loaded_model.model.max_seq_length
 
-    tokenizer_kwargs = {'pad_to_max_length': pad_to_max_length}
+    tokenizer_kwargs = {'padding_strategy': padding_strategy}
     max_seq = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
     max_seq_minus_one = "x " * (model_max - 3)  # 1 token length shorter than max_seq_length
-    short = "x "
+    single = "x "
 
-    normal_result = loaded_model._encode_with_retry(
-        [max_seq_minus_one, max_seq, short], return_token_count=True
-    )
-    padded_result = loaded_model._encode_with_retry(
-        [max_seq_minus_one, max_seq, short], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
-    )
-
-    assert np.all(normal_result.embedding == padded_result.embedding)
+    if padding_strategy is True:
+        normal_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one], return_token_count=True
+        )
+        padded_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
+        )
+        assert np.all(normal_result.embedding == padded_result.embedding)
+    elif padding_strategy == 'max_length':
+        with patch('caikit_nlp.modules.text_embedding.embedding.sum_token_count') as mock_sum_token_count:
+            mock_sum_token_count.side_effect = custom_sum_token_count
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert not np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
+            )
+            assert normal_result.input_token_count == padded_result.input_token_count
+            assert np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert not np.all(normal_result.embedding == padded_result.embedding)

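Editor's note (not part of the commit): the sketch below illustrates the tokenizer behavior the new test exercises, namely that padding="max_length" lengthens the attention_mask with zeros, so a count based on len(encoding.attention_mask) (as in custom_sum_token_count) diverges from one based on the number of non-padding tokens. The checkpoint name is only an illustrative assumption; any Hugging Face fast tokenizer behaves the same way.

from transformers import AutoTokenizer

# Arbitrary example checkpoint, not taken from the repository under test.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

unpadded = tokenizer(["x x"])
padded = tokenizer(["x x"], padding="max_length", max_length=16)

# The number of real (non-padding) tokens is unchanged by padding ...
assert sum(unpadded["attention_mask"][0]) == sum(padded["attention_mask"][0])

# ... but the padded attention_mask now has max_length entries, which is the
# quantity custom_sum_token_count measures via len(encoding.attention_mask).
assert len(padded["attention_mask"][0]) == 16
assert len(unpadded["attention_mask"][0]) < 16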