Lambdas to funcs
kddubey committed Oct 28, 2024
1 parent 3b7ea04 commit c97b761
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/huggingface/test_huggingface_classify.py
@@ -228,10 +228,11 @@ def test_cache_nested(model_and_tokenizer, atol):
     # for SentencePiece tokenizers like Llama's
     delim = ""
 
-    logits = lambda *args, **kwargs: hf._utils.logits_texts(*args, **kwargs)[0]
-    """
-    Returns next-token logits for each token in an inputted text.
-    """
+    def logits(*args, **kwargs) -> torch.Tensor:
+        """
+        Returns next-token logits for each token in an inputted text.
+        """
+        return hf._utils.logits_texts(*args, **kwargs)[0]
 
     with classify.cache(model_and_tokenizer, "a") as cached_a:
         with classify.cache(cached_a, delim + "b c") as cached_a_b_c:
@@ -241,9 +242,8 @@ def test_cache_nested(model_and_tokenizer, atol):
             logits3 = logits([delim + "1 2 3"], cached_a_b_c)
         logits4 = logits([delim + "b c d"], cached_a)
 
-    logits_correct = lambda texts, **kwargs: logits(
-        texts, model_and_tokenizer, drop_bos_token=False
-    )
+    def logits_correct(texts: Sequence[str], **kwargs) -> torch.Tensor:
+        return logits(texts, model_and_tokenizer, drop_bos_token=False)
 
     assert torch.allclose(logits1, logits_correct(["a b c d e f"]), atol=atol)
     assert torch.allclose(logits2, logits_correct(["a b c d x"]), atol=atol)
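
One reason the change helps (a minimal illustrative sketch, not code from this
repository; the names below are made up): a string literal written after a lambda
assignment is just a bare expression statement, so it never becomes the function's
docstring, whereas a def gets a real __doc__, a real __name__, and room for type hints.

    # illustrative only -- not part of this diff
    square = lambda x: x * x
    """Squares x."""  # a no-op string expression; square.__doc__ stays None

    def square_fn(x: float) -> float:
        """Squares x."""
        return x * x

    assert square.__doc__ is None
    assert square_fn.__doc__ == "Squares x."
    assert square_fn.__name__ == "square_fn"  # a lambda's __name__ is "<lambda>"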
