Skip to content

Commit

Permalink
Update device checking, removed unused generator, and raise valueerro…
Browse files Browse the repository at this point in the history
…r if pad token id is not passed with early stopping
  • Loading branch information
robertsonwang committed Oct 30, 2024
1 parent 9daa14b commit 57641c3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/synthid_text/logits_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def watermarked_call(
])

device = scores.device
if device != self.device:
if str(device) != str(self.device):
raise ValueError(
"SynthIDLogitsProcessor received inputs with unexpected device.",
)
Expand Down
1 change: 0 additions & 1 deletion src/synthid_text/logits_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_mean_g_value_matches_theoretical(
)

probs = torch.nn.functional.softmax(updated_scores, dim=1)
generator = torch.Generator(device=device).manual_seed(0)
next_tokens = torch.multinomial(
probs,
num_samples=1,
Expand Down
5 changes: 4 additions & 1 deletion src/synthid_text/synthid_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ def _sample(
"`do_sample` is set to `True`, `logits_warper` must be a"
f" `LogitsProcessorList` instance (it is {logits_warper})."
)

if has_eos_stopping_criteria and not pad_token_id:
raise ValueError(
"`stopping_criteria` is not empty, `pad_token_id` must be set in generation_config."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
Expand Down

0 comments on commit 57641c3

Please sign in to comment.