diff --git a/tests/huggingface/test_huggingface_classify.py b/tests/huggingface/test_huggingface_classify.py
index f42a0df..6d7cb39 100644
--- a/tests/huggingface/test_huggingface_classify.py
+++ b/tests/huggingface/test_huggingface_classify.py
@@ -228,10 +228,11 @@ def test_cache_nested(model_and_tokenizer, atol):
         # for SentencePiece tokenizers like Llama's
         delim = ""
 
-    logits = lambda *args, **kwargs: hf._utils.logits_texts(*args, **kwargs)[0]
-    """
-    Returns next-token logits for each token in an inputted text.
-    """
+    def logits(*args, **kwargs) -> torch.Tensor:
+        """
+        Returns next-token logits for each token in an inputted text.
+        """
+        return hf._utils.logits_texts(*args, **kwargs)[0]
 
     with classify.cache(model_and_tokenizer, "a") as cached_a:
         with classify.cache(cached_a, delim + "b c") as cached_a_b_c:
@@ -241,9 +242,8 @@ def test_cache_nested(model_and_tokenizer, atol):
             logits3 = logits([delim + "1 2 3"], cached_a_b_c)
             logits4 = logits([delim + "b c d"], cached_a)
 
-    logits_correct = lambda texts, **kwargs: logits(
-        texts, model_and_tokenizer, drop_bos_token=False
-    )
+    def logits_correct(texts: Sequence[str], **kwargs) -> torch.Tensor:
+        return logits(texts, model_and_tokenizer, drop_bos_token=False)
 
     assert torch.allclose(logits1, logits_correct(["a b c d e f"]), atol=atol)
     assert torch.allclose(logits2, logits_correct(["a b c d x"]), atol=atol)