Skip to content

Commit

Permalink
now using hf's implementation for SynthID watermarking
Browse files Browse the repository at this point in the history
  • Loading branch information
marluxiaboss committed Nov 7, 2024
1 parent b51a976 commit d4ba6c3
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 19 deletions.
30 changes: 24 additions & 6 deletions detector_benchmark/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ElectraForSequenceClassification,
ElectraTokenizer,
AutoConfig,
SynthIDTextWatermarkingConfig,
)
import torch
import argparse
Expand Down Expand Up @@ -62,12 +63,29 @@ def choose_watermarking_scheme(cfg: DictConfig, watermarking_scheme_name: str, g

algorithm_config = cfg.watermark

watermarking_scheme = AutoWatermark.load(
watermarking_scheme_name,
algorithm_config=algorithm_config,
gen_model=gen,
model_config=model_config,
)
# temporary band-aid fix for SynthID while we try to find a better solution
if watermarking_scheme_name == "SynthID":

# The keys and ngram_len passed to the constructor below are only placeholders;
# the actual values come from the hydra config and are set on the object afterwards.
watermarking_config = SynthIDTextWatermarkingConfig(
keys=[654, 400, 836, 123, 340, 443, 597, 160, 57, ...],
ngram_len=5,
)

for k, v in algorithm_config.items():
setattr(watermarking_config, k, v)

watermarking_scheme = watermarking_config

else:

watermarking_scheme = AutoWatermark.load(
watermarking_scheme_name,
algorithm_config=algorithm_config,
gen_model=gen,
model_config=model_config,
)

return watermarking_scheme

Expand Down
19 changes: 12 additions & 7 deletions detector_benchmark/generation/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AutoModelForCausalLM,
LogitsProcessor,
LogitsProcessorList,
SynthIDTextWatermarkingConfig,
)
from ..utils.configs import ModelConfig
from ..watermark.auto_watermark import AutoWatermark
Expand Down Expand Up @@ -68,12 +69,6 @@ def forward(
outputs_list = []
for i in tqdm(range(0, len(samples), batch_size), desc="Generating text"):

# specific for SynthID watermarking scheme, we need to reset the state
if watermarking_scheme is not None and hasattr(
watermarking_scheme.logits_processor, "state"
):
watermarking_scheme.logits_processor.state = None

batch_samples = samples[i : i + batch_size]
encoding = self.tokenizer.batch_encode_plus(
batch_samples, return_tensors="pt", padding=True, truncation=True
Expand All @@ -83,8 +78,18 @@ def forward(
with torch.no_grad():
if watermarking_scheme is not None:

# special case for SynthID
if isinstance(watermarking_scheme, SynthIDTextWatermarkingConfig):

output_ids = self.generator.generate(
input_ids,
pad_token_id=self.tokenizer.pad_token_id,
watermarking_config=watermarking_scheme,
**self.gen_params
)

# if the watermarking scheme has a logits processor, use it
if hasattr(watermarking_scheme, "logits_processor"):
elif hasattr(watermarking_scheme, "logits_processor"):
output_ids = self.generator.generate(
input_ids,
pad_token_id=self.tokenizer.pad_token_id,
Expand Down
8 changes: 4 additions & 4 deletions detector_benchmark/watermark/synth_id/synth_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(self, algorithm_config: dict, gen_model, model_config, *args, **kwa
self.vocab_size = self.generation_tokenizer.vocab_size
self.device = model_config.device
self.gen_kwargs = model_config.gen_params
# self.top_k = model_config.gen_params["top_k"]
self.top_k = -1
# self.temperature = model_config.gen_params["temperature"]
self.temperature = 0.7
self.top_k = model_config.gen_params["top_k"]
# self.top_k = -1
self.temperature = model_config.gen_params["temperature"]
# self.temperature = 0.7


class SynthIDUtils:
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ tempdir==0.7.1
termcolor==2.4.0
threadpoolctl==3.3.0
tinycss2==1.3.0
tokenizers==0.19.1
tokenizers==0.20.3
torch>=2.2.1
tornado==6.4
tqdm==4.66.2
tqdm-multiprocess==0.0.11
traitlets==5.14.3
transformers==4.44.2
transformers==4.46.2
typepy==1.3.2
typing_extensions
tyro==0.8.3
Expand Down

0 comments on commit d4ba6c3

Please sign in to comment.