Skip to content

Commit

Permalink
Merge pull request #365 from flaviabeo/tokenize
Browse files Browse the repository at this point in the history
Adds tokenization task for embeddings
  • Loading branch information
evaline-ju authored Jul 24, 2024
2 parents 44d61a5 + 2a8321e commit 1499cab
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
27 changes: 27 additions & 0 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
SentenceSimilarityResult,
SentenceSimilarityResults,
SentenceSimilarityScores,
Token,
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import (
EmbeddingTask,
Expand All @@ -63,6 +65,7 @@
RerankTasks,
SentenceSimilarityTask,
SentenceSimilarityTasks,
TokenizationTask,
)
import alog

Expand Down Expand Up @@ -135,6 +138,7 @@ class TruncatedTokensTuple(NamedTuple):
SentenceSimilarityTasks,
RerankTask,
RerankTasks,
TokenizationTask,
],
)
class EmbeddingModule(ModuleBase):
Expand Down Expand Up @@ -201,6 +205,29 @@ def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argumen
"sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
}

@TokenizationTask.taskmethod()
def run_tokenizer(
self,
text: str,
) -> TokenizationResults:
"""Run tokenization task against the model
Args:
text: str
Text to tokenize
Returns:
TokenizationResults
The token count
"""
result = self.model._get_tokenized([text])

mapping = [
interv for interv in result.offset_mapping[0] if (interv[1] - interv[0]) > 0
]
tokens = [Token(start=i[0], end=i[1], text=text[i[0] : i[1]]) for i in mapping]

return TokenizationResults(token_count=len(result.input_ids[0]), results=tokens)

@classmethod
def _get_ipex(cls, ipex_flag):
"""Get IPEX optimization library if enabled and available, else return False
Expand Down
10 changes: 10 additions & 0 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
RerankResults,
RerankScore,
RerankScores,
Token,
TokenizationResults,
)

# Local
Expand Down Expand Up @@ -280,6 +282,14 @@ def test_run_embeddings(loaded_model):
assert res.input_token_count == INPUT_TOKEN_COUNT


def test_run_tokenization(loaded_model):
res = loaded_model.run_tokenizer(text=INPUT)
assert isinstance(res, TokenizationResults)
assert isinstance(res.results, list)
assert isinstance(res.results[0], Token)
assert res.token_count == INPUT_TOKEN_COUNT


@pytest.mark.parametrize(
"query,docs,top_n",
[
Expand Down

0 comments on commit 1499cab

Please sign in to comment.