Skip to content

Commit

Permalink
Merge pull request #6 from robertsonwang/small_fixes
Browse files Browse the repository at this point in the history
Small Fixes
  • Loading branch information
sumedhghaisas2 authored Nov 12, 2024
2 parents f6d3e54 + 57641c3 commit 3a068a3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/synthid_text/logits_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def watermarked_call(
])

device = scores.device
if device != self.device:
if str(device) != str(self.device):
raise ValueError(
"SynthIDLogitsProcessor received inputs with unexpected device.",
)
Expand Down
1 change: 0 additions & 1 deletion src/synthid_text/logits_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def does_mean_g_value_matches_theoretical(
)

probs = torch.nn.functional.softmax(updated_scores, dim=1)
generator = torch.Generator(device=device).manual_seed(0)
next_tokens = torch.multinomial(
probs,
num_samples=1,
Expand Down
5 changes: 4 additions & 1 deletion src/synthid_text/synthid_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def _sample(
"`do_sample` is set to `True`, `logits_warper` must be a"
f" `LogitsProcessorList` instance (it is {logits_warper})."
)

if has_eos_stopping_criteria and not pad_token_id:
raise ValueError(
"`stopping_criteria` is not empty, `pad_token_id` must be set in generation_config."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
Expand Down

0 comments on commit 3a068a3

Please sign in to comment.