Add Watermarking LogitsProcessor and WatermarkDetector (#29676)
* add watermarking processor

* remove the other hashing (context width=1 always)

* make style

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <[email protected]>

* update watermarking process

* add detector

* update tests to use detector

* fix failing tests

* rename `input_seq`

* make style

* doc for processor

* minor fixes

* docs

* make quality

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/logits_process.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <[email protected]>

* add PR suggestions

* let's use lru_cache's default max size (128)

* import processor if torch available

* maybe like this

* let's move the config to a torch-independent file

* add docs

* tiny docs fix to make the test happy

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Joao Gante <[email protected]>

* Update src/transformers/generation/watermarking.py

Co-authored-by: Joao Gante <[email protected]>

* PR suggestions

* add docs

* fix test

* fix docs

* address pr comments

* style

* Revert "style"

This reverts commit 7f33cc3.

* correct style

* make doctest green

---------

Co-authored-by: Joao Gante <[email protected]>
zucchini-nlp and gante authored May 14, 2024
1 parent 65ea190 commit 5ad960f
Showing 12 changed files with 738 additions and 4 deletions.
49 changes: 49 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -173,6 +173,55 @@ your screen, one word at a time:
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```


## Watermarking

The `generate()` method supports watermarking generated text by randomly marking a portion of the tokens as "green".
During generation, "green" tokens have a small 'bias' value added to their logits, giving them a higher chance of being generated.
Watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is,
statistically, that human-generated text would contain that many "green" tokens. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more details on
the inner workings of watermarking, refer to the paper.
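The core idea can be sketched in a few lines of standalone PyTorch (a simplified illustration of the scheme, not the actual `WatermarkLogitsProcessor` code; `green_fraction`, `bias`, and `hash_key` are illustrative placeholder values):

```python
import torch

def watermark_bias_logits(scores, prev_token, vocab_size, green_fraction=0.25, bias=2.5, hash_key=15485863):
    """Add `bias` to the logits of a pseudo-randomly chosen "green" subset of the vocabulary.

    The green set is seeded from the previous token (context width = 1), so a
    detector can later recompute it without any extra model.
    """
    rng = torch.Generator()
    rng.manual_seed(hash_key * int(prev_token) % (2**63 - 1))  # seed derived from the context token
    perm = torch.randperm(vocab_size, generator=rng)
    green_ids = perm[: int(green_fraction * vocab_size)]  # e.g. 25% of the vocab is "green"
    biased = scores.clone()
    biased[green_ids] += bias  # green tokens become slightly more likely
    return biased, green_ids
```

Because the green set depends only on the previous token and a fixed key, the same partition can be deterministically reconstructed at detection time.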

Watermarking can be used with any generative model in `transformers` and does not require an extra classification model
to detect watermarked text. To trigger watermarking, pass a [`WatermarkingConfig`] with the needed arguments directly to the
`.generate()` method, or add it to the [`GenerationConfig`]. Watermarked text can later be detected with a [`WatermarkDetector`].


<Tip warning={true}>

The `WatermarkDetector` internally relies on the proportion of "green" tokens and on whether the generated text follows the coloring pattern.
That is why it is recommended to strip off the prompt text if it is much longer than the generated text.
This also matters when one sequence in the batch is much longer than the others, causing the other rows to be padded.
Additionally, the detector **must** be initialized with the same watermark configuration arguments that were used when generating.

</Tip>

Let's generate some text with watermarking. In the code snippet below, we set the bias to 2.5, the value that
will be added to the logits of "green" tokens. After generating watermarked text, we can pass it directly to the `WatermarkDetector`
to check whether the text is machine-generated (it outputs `True` for machine-generated text and `False` otherwise).

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig

>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> tok.pad_token_id = tok.eos_token_id
>>> tok.padding_side = "left"

>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
>>> input_len = inputs["input_ids"].shape[-1]

>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)

>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection_out = detector(out, return_dict=True)
>>> detection_out.prediction
array([True, True])
```
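For intuition, the detector's statistical test boils down to a one-proportion z-test on the count of "green" tokens (a hedged sketch; the real `WatermarkDetector` additionally recomputes the green sets from the seeding scheme and handles batching, which is omitted here):

```python
import math

def green_token_z_score(num_green, total, green_fraction=0.25):
    # Null hypothesis: human-written text, where each token lands in the green
    # set with probability `green_fraction`. A large positive z-score means the
    # text contains far more green tokens than chance would allow.
    expected = green_fraction * total
    std = math.sqrt(total * green_fraction * (1 - green_fraction))
    return (num_green - expected) / std
```

A threshold on this z-score (e.g. `z > 3`) then yields a boolean machine-generated/human prediction like the one in the example above.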


## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
11 changes: 11 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -209,6 +209,10 @@ generation.
[[autodoc]] WhisperTimeStampLogitsProcessor
- __call__

[[autodoc]] WatermarkLogitsProcessor
- __call__


### TensorFlow

[[autodoc]] TFForcedBOSTokenLogitsProcessor
@@ -372,3 +376,10 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- update
- get_seq_length
- reorder_cache


## Watermark Utils

[[autodoc]] WatermarkDetector
- __call__

2 changes: 2 additions & 0 deletions docs/source/en/main_classes/text_generation.md
@@ -41,6 +41,8 @@ like token streaming.
- validate
- get_generation_mode

[[autodoc]] generation.WatermarkingConfig

## GenerationMixin

[[autodoc]] generation.GenerationMixin
13 changes: 11 additions & 2 deletions src/transformers/__init__.py
@@ -117,7 +117,12 @@
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"generation": [
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
"WatermarkingConfig",
],
"hf_argparser": ["HfArgumentParser"],
"hyperparameter_search": [],
"image_transforms": [],
@@ -1232,6 +1237,8 @@
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WatermarkDetector",
"WatermarkLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
)
@@ -4617,7 +4624,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser

# Integrations
@@ -5797,6 +5804,8 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkDetector,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_utils import PreTrainedModel
14 changes: 12 additions & 2 deletions src/transformers/generation/__init__.py
@@ -18,7 +18,7 @@


_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}

@@ -78,6 +78,7 @@
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"WatermarkLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
@@ -106,6 +107,10 @@
"GenerateDecoderOnlyOutput",
"GenerateEncoderDecoderOutput",
]
_import_structure["watermarking"] = [
"WatermarkDetector",
"WatermarkDetectorOutput",
]

try:
if not is_tf_available():
@@ -174,7 +179,7 @@
]

if TYPE_CHECKING:
from .configuration_utils import GenerationConfig, GenerationMode
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
from .streamers import TextIteratorStreamer, TextStreamer

try:
@@ -218,6 +223,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
@@ -247,6 +253,10 @@
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
from .watermarking import (
WatermarkDetector,
WatermarkDetectorOutput,
)

try:
if not is_tf_available():