Skip to content

Commit

Permalink
fixed a couple of issues
Browse files Browse the repository at this point in the history
  • Loading branch information
marluxiaboss committed Nov 8, 2024
1 parent 7bf7166 commit f5fa553
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion detector_benchmark/conf/watermark/watermark_synth_id.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ sampling_table_size: 65536
sampling_table_seed: 0
context_history_size: 1024
detector_type: mean
threshold: 0.52
z_threshold: 0.52
2 changes: 1 addition & 1 deletion detector_benchmark/detector/fast_detect_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def detect(
reference_model_name = "gpt-neo-2.7B"
scoring_model_name = "gpt-neo-2.7B"

ref_path = "detector/local_infer_ref"
ref_path = "local_infer_ref"
device = self.device

ref_model = self.ref_model
Expand Down
5 changes: 5 additions & 0 deletions detector_benchmark/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def test_detector(cfg: DictConfig):
# check that we have the same watermark algorithm
assert watermark_config["algorithm_name"] == watermarking_scheme_name

# check if there is a key called "threshold" in the watermark_config
if "threshold" in watermark_config:
# change it to "z_threshold"
watermark_config["z_threshold"] = watermark_config["threshold"]

# modify all values of cfg.watermark to the values in watermark_config
for key, value in cfg.watermark.items():
cfg.watermark[key] = watermark_config[key]
Expand Down
6 changes: 3 additions & 3 deletions detector_benchmark/watermark/synth_id/synth_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, algorithm_config: dict, gen_model, model_config, *args, **kwa
self.sampling_table_seed = config_dict["sampling_table_seed"]
self.context_history_size = config_dict["context_history_size"]
self.detector_name = config_dict["detector_type"]
self.threshold = config_dict["threshold"]
self.threshold = config_dict["z_threshold"]

# Model configuration
self.generation_model = gen_model
Expand Down Expand Up @@ -504,10 +504,10 @@ def __init__(
self.utils = SynthIDUtils(self.config)

# fix issue with the keys parameter of the config
keys_str = self.config["keys"]
keys_str = self.config.keys
keys = keys_str.replace("[", "").replace("]", "").split(",")
keys = [int(key) for key in keys]
self.config["keys"] = keys
self.config.keys = keys

self.logits_processor = SynthIDLogitsProcessor(self.config, self.utils)

Expand Down

0 comments on commit f5fa553

Please sign in to comment.