Skip to content

Commit

Permalink
Revert "Feat/chipper repetitions (#295)"
Browse files — browse the repository at this point in the history
This reverts commit 54e3e46.
  • Loading branch information
ajjimeno authored Dec 20, 2023
1 parent 54e3e46 commit 8e21f74
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 253 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
## 0.7.21

* Revised repetitions for Chipper

## 0.7.20

* chipper-v3: improved table prediction
Expand Down
40 changes: 5 additions & 35 deletions test_unstructured_inference/models/test_chippermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_no_repeat_ngram_logits():

no_repeat_ngram_size = 2

logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2, context_length=10)
logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2)
output = logitsProcessor(input_ids=input_ids, scores=logits)

assert (
Expand Down Expand Up @@ -194,49 +194,20 @@ def test_ngram_repetiton_stopping_criteria():
logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])

stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
ngram_size=2, context_length=10, skip_tokens={0, 1, 2, 3, 4}
repetition_window=2, skip_tokens={0, 1, 2, 3, 4}
)

output = stoppingCriteria(input_ids=input_ids, scores=logits)

assert output is False

stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
ngram_size=2, context_length=10, skip_tokens={1, 2, 3, 4}
repetition_window=2, skip_tokens={1, 2, 3, 4}
)
output = stoppingCriteria(input_ids=input_ids, scores=logits)
assert output is True


def test_no_repeat_group_ngram_logits_processor():
input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]])
logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])

logitsProcessor = chipper.NoRepeatGroupNGramLogitsProcessor(ngram_size=3, token_group=[1, 2])

output = logitsProcessor(input_ids=input_ids, scores=logits)

assert (
int(
torch.sum(
output == torch.tensor([[0.1000, -0.3000, -0.5000, 0.0000, 1.0000, -0.9000]]),
),
)
== 6
)


def test_target_token_id_stopping_criterion():
input_ids = torch.tensor([1, 2, 3])
logits = torch.tensor([0.1, 0.2, 0.3])

stoppingCriterion = chipper.TargetTokenIdStoppingCriterion(1)

output = stoppingCriterion(input_ids=input_ids, scores=logits)

assert output is True


@pytest.mark.parametrize(
("decoded_str", "expected_classes"),
[
Expand Down Expand Up @@ -288,8 +259,7 @@ def test_predict_tokens_beam_indices():
model = get_model("chipper")
model.stopping_criteria = [
chipper.NGramRepetitonStoppingCriteria(
ngram_size=1,
context_length=10,
repetition_window=1,
skip_tokens={},
),
]
Expand Down Expand Up @@ -326,7 +296,7 @@ def test_deduplicate_detected_elements():

def test_norepeatnGramlogitsprocessor_exception():
with pytest.raises(ValueError):
chipper.NoRepeatNGramLogitsProcessor(ngram_size="", context_length=10)
chipper.NoRepeatNGramLogitsProcessor(ngram_size="")


def test_run_chipper_v3():
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.21" # pragma: no cover
__version__ = "0.7.20" # pragma: no cover
Loading

0 comments on commit 8e21f74

Please sign in to comment.