From 9644e265712f4fc1cc70db1913455946ff365df4 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 11 Nov 2024 15:19:09 +0000 Subject: [PATCH 1/3] Fixes #8. Use Py3.9 features and type hints --- ...synthid_text_huggingface_integration.ipynb | 1729 ++++++++--------- src/synthid_text/detector_bayesian.py | 36 +- src/synthid_text/detector_mean.py | 3 +- src/synthid_text/synthid_mixin.py | 10 +- src/synthid_text/synthid_mixin_test.py | 2 +- 5 files changed, 891 insertions(+), 889 deletions(-) diff --git a/notebooks/synthid_text_huggingface_integration.ipynb b/notebooks/synthid_text_huggingface_integration.ipynb index 4745b3e..b2f7ed5 100644 --- a/notebooks/synthid_text_huggingface_integration.ipynb +++ b/notebooks/synthid_text_huggingface_integration.ipynb @@ -1,868 +1,867 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Cthb8O3LCPM1" - }, - "source": [ - "# SynthID Text: Watermarking for Generated Text\n", - "\n", - "This notebook demonstrates how to use the [SynthID Text library][synthid-code]\n", - "to apply and detect watermarks on generated text. It is divided into three major\n", - "sections and intended to be run end-to-end.\n", - "\n", - "1. **_Setup_**: Importing the SynthID Text library, choosing your model (either\n", - " [Gemma][gemma] or [GPT-2][gpt2]) and device (either CPU or GPU, depending\n", - " on your runtime), defining the watermarking configuration, and initializing\n", - " some helper functions.\n", - "1. **_Applying a watermark_**: Loading your selected model using the\n", - " [Hugging Face Transformers][transformers] library, using that model to\n", - " generate some watermarked text, and comparing the perplexity of the\n", - " watermarked text to that of text generated by the base model.\n", - "1. **_Detecting a watermark_**: Training a detector to recognize text generated\n", - " with a specific watermarking configuration, and then using that detector to\n", - " predict whether a set of examples were generated with that configuration.\n", - "\n", - "As the reference implementation for the\n", - "[SynthID Text paper in _Nature_][synthid-paper], this library and notebook are\n", - "intended for research review and reproduction only. They should not be used in\n", - "production systems. For a production-grade implementation, check out the\n", - "official SynthID logits processor in [Hugging Face Transformers][transformers].\n", - "\n", - "[gemma]: https://ai.google.dev/gemma/docs/model_card\n", - "[gpt2]: https://huggingface.co/openai-community/gpt2\n", - "[synthid-code]: https://github.com/google-deepmind/synthid-text\n", - "[synthid-paper]: https://www.nature.com/\n", - "[transformers]: https://huggingface.co/docs/transformers/en/index" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "be-I0MNRbyWT" - }, - "source": [ - "# 1. Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "aq7hChW8njFo" - }, - "outputs": [], - "source": [ - "# @title Install and import the required Python packages\n", - "#\n", - "# @markdown Running this cell may require you to restart your session.\n", - "\n", - "! 
pip install synthid-text[notebook]\n", - "\n", - "from collections.abc import Sequence\n", - "import enum\n", - "import gc\n", - "\n", - "import datasets\n", - "import huggingface_hub\n", - "from synthid_text import detector_mean\n", - "from synthid_text import logits_processing\n", - "from synthid_text import synthid_mixin\n", - "from synthid_text import detector_bayesian\n", - "import tensorflow as tf\n", - "import torch\n", - "import tqdm\n", - "import transformers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "w9a5nANolFS_" - }, - "outputs": [], - "source": [ - "# @title Choose your model.\n", - "#\n", - "# @markdown This reference implementation is configured to use the Gemma v1.0\n", - "# @markdown Instruction-Tuned variants in 2B or 7B sizes, or GPT-2.\n", - "\n", - "\n", - "class ModelName(enum.Enum):\n", - " GPT2 = 'gpt2'\n", - " GEMMA_2B = 'google/gemma-2b-it'\n", - " GEMMA_7B = 'google/gemma-7b-it'\n", - "\n", - "\n", - "model_name = 'google/gemma-7b-it' # @param ['gpt2', 'google/gemma-2b-it', 'google/gemma-7b-it']\n", - "MODEL_NAME = ModelName(model_name)\n", - "\n", - "if MODEL_NAME is not ModelName.GPT2:\n", - " huggingface_hub.notebook_login()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "B_pe-hG6SW6H" - }, - "outputs": [], - "source": [ - "# @title Configure your device\n", - "#\n", - "# @markdown This notebook loads models from Hugging Face Transformers into the\n", - "# @markdown PyTorch deep learning runtime. PyTorch supports generation on CPU or\n", - "# @markdown GPU, but your chosen model will run best on the following hardware,\n", - "# @markdown some of which may require a\n", - "# @markdown [Colab Subscription](https://colab.research.google.com/signup).\n", - "# @markdown\n", - "# @markdown * Gemma v1.0 2B IT: Use a GPU with 16GB of memory, such as a T4.\n", - "# @markdown * Gemma v1.0 7B IT: Use a GPU with 32GB of memory, such as an A100.\n", - "# @markdown * GPT-2: Any runtime will work, though a High-RAM CPU or any GPU\n", - "# @markdown will be faster.\n", - "\n", - "DEVICE = (\n", - " torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n", - ")\n", - "DEVICE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "UOGvCjyVjjQ5" - }, - "outputs": [], - "source": [ - "# @title Example watermarking config\n", - "#\n", - "# @markdown SynthID Text produces unique watermarks given a configuration, with\n", - "# @markdown the most important piece of a configuration being the `keys`: a\n", - "# @markdown sequence of unique integers.\n", - "# @markdown\n", - "# @markdown This reference implementation uses a fixed watermarking\n", - "# @markdown configuration, which will be displayed when you run this cell.\n", - "\n", - "CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG\n", - "CONFIG" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "79mekKj5UUZR" - }, - "outputs": [], - "source": [ - "# @title Initialize the required constants, tokenizer, and logits processor\n", - "\n", - "BATCH_SIZE = 8\n", - "NUM_BATCHES = 320\n", - "OUTPUTS_LEN = 1024\n", - "TEMPERATURE = 0.5\n", - "TOP_K = 40\n", - "TOP_P = 0.99\n", - "\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)\n", - "tokenizer.pad_token = tokenizer.eos_token\n", - "tokenizer.padding_side = \"left\"\n", - "\n", - 
"logits_processor = logits_processing.SynthIDLogitsProcessor(\n", - " **CONFIG, top_k=TOP_K, temperature=TEMPERATURE\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "hndT3YCQUt6D" - }, - "outputs": [], - "source": [ - "# @title Utility functions to load models, compute perplexity, and process prompts.\n", - "\n", - "\n", - "def load_model(\n", - " model_name: ModelName,\n", - " expected_device: torch.device,\n", - " enable_watermarking: bool = False,\n", - ") -> transformers.PreTrainedModel:\n", - " match model_name:\n", - " case ModelName.GPT2:\n", - " model_cls = (\n", - " synthid_mixin.SynthIDGPT2LMHeadModel\n", - " if enable_watermarking\n", - " else transformers.GPT2LMHeadModel\n", - " )\n", - " model = model_cls.from_pretrained(model_name.value, device_map='auto')\n", - " case ModelName.GEMMA_2B | ModelName.GEMMA_7B:\n", - " model_cls = (\n", - " synthid_mixin.SynthIDGemmaForCausalLM\n", - " if enable_watermarking\n", - " else transformers.GemmaForCausalLM\n", - " )\n", - " model = model_cls.from_pretrained(\n", - " model_name.value,\n", - " device_map='auto',\n", - " torch_dtype=torch.bfloat16,\n", - " )\n", - "\n", - " if model.device != expected_device:\n", - " raise ValueError('Model device not as expected.')\n", - " return model\n", - "\n", - "\n", - "def _compute_perplexity(\n", - " outputs: torch.LongTensor,\n", - " scores: torch.FloatTensor,\n", - " eos_token_mask: torch.LongTensor,\n", - " watermarked: bool = False,\n", - ") -> float:\n", - " \"\"\"Compute perplexity given the model outputs and the logits.\"\"\"\n", - " len_offset = len(scores)\n", - " if watermarked:\n", - " nll_scores = scores\n", - " else:\n", - " nll_scores = [\n", - " torch.gather(\n", - " -torch.log(torch.nn.Softmax(dim=1)(sc)),\n", - " 1,\n", - " outputs[:, -len_offset + idx, None],\n", - " )\n", - " for idx, sc in enumerate(scores)\n", - " ]\n", - " nll_sum = torch.nan_to_num(\n", - " torch.squeeze(torch.stack(nll_scores, dim=1), dim=2)\n", - " * eos_token_mask.long(),\n", - " posinf=0,\n", - " )\n", - " nll_sum = nll_sum.sum(dim=1)\n", - " nll_mean = nll_sum / eos_token_mask.sum(dim=1)\n", - " return nll_mean.sum(dim=0)\n", - "\n", - "\n", - "def _process_raw_prompt(prompt: Sequence[str]) -> str:\n", - " \"\"\"Add chat template to the raw prompt.\"\"\"\n", - " match MODEL_NAME:\n", - " case ModelName.GPT2:\n", - " return prompt.decode().strip('\"')\n", - " case ModelName.GEMMA_2B | ModelName.GEMMA_7B:\n", - " return tokenizer.apply_chat_template(\n", - " [{'role': 'user', 'content': prompt.decode().strip('\"')}],\n", - " tokenize=False,\n", - " add_generation_prompt=True,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qs9Ih8r4Dyu5" - }, - "source": [ - "# 2. 
Applying a watermark" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "JJ28Aajwu9uD" - }, - "outputs": [], - "source": [ - "# @title Generate watermarked output\n", - "\n", - "gc.collect()\n", - "torch.cuda.empty_cache()\n", - "\n", - "batch_size = 1\n", - "example_inputs = [\n", - " 'I enjoy walking with my cute dog',\n", - " 'I am from New York',\n", - " 'The test was not so very hard after all',\n", - " \"I don't think they can score twice in so short a time\",\n", - "]\n", - "example_inputs = example_inputs * (int(batch_size / 4) + 1)\n", - "example_inputs = example_inputs[:batch_size]\n", - "\n", - "inputs = tokenizer(\n", - " example_inputs,\n", - " return_tensors='pt',\n", - " padding=True,\n", - ").to(DEVICE)\n", - "\n", - "model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)\n", - "torch.manual_seed(0)\n", - "outputs = model.generate(\n", - " **inputs,\n", - " do_sample=True,\n", - " temperature=0.7,\n", - " max_length=1024,\n", - " top_k=40,\n", - ")\n", - "\n", - "print('Output:\\n' + 100 * '-')\n", - "for i, output in enumerate(outputs):\n", - " print(tokenizer.decode(output, skip_special_tokens=True))\n", - " print(100 * '-')\n", - "\n", - "del inputs, outputs, model\n", - "gc.collect()\n", - "torch.cuda.empty_cache()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z6VJm-ZjJ3Q8" - }, - "source": [ - "## [Optional] Compare perplexity between watermarked and non-watermarked text\n", - "\n", - "Sample [eli5 dataset](https://facebookresearch.github.io/ELI5/) outputs from\n", - "watermarked and non-watermarked models and verify that:\n", - "\n", - "* The [perplexity](https://huggingface.co/docs/transformers/en/perplexity) of\n", - " watermarked and non-watermarked text is similar.\n", - "\n", - "$$\\text{PPL}(X) = \\exp \\left\\{ {-\\frac{1}{t}\\sum_i^t \\log p_\\theta (x_i|x_{= padded_length:\n", - " line = line[:padded_length]\n", - " else:\n", - " line = line + [\n", - " tokenizer.eos_token_id for _ in range(padded_length - len(line))\n", - " ]\n", - " batched.append(torch.tensor(line, dtype=torch.long, device=DEVICE)[None, :])\n", - " if len(batched) == NEG_BATCH_SIZE:\n", - " tokenized_uwm_outputs.append(torch.cat(batched, dim=0))\n", - " batched = []\n", - " if i > NUM_NEGATIVES:\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "UA6iSRKmklTM" - }, - "outputs": [], - "source": [ - "# @title Train the Bayesian detector\n", - "bayesian_detector, test_loss = (\n", - " detector_bayesian.BayesianDetector.train_best_detector(\n", - " tokenized_wm_outputs=wm_outputs,\n", - " tokenized_uwm_outputs=tokenized_uwm_outputs,\n", - " logits_processor=logits_processor,\n", - " tokenizer=tokenizer,\n", - " torch_device=DEVICE,\n", - " max_padded_length=MAX_PADDED_LENGTH,\n", - " pos_truncation_length=POS_TRUNCATION_LENGTH,\n", - " neg_truncation_length=NEG_TRUNCATION_LENGTH,\n", - " verbose=True,\n", - " learning_rate=3e-3,\n", - " n_epochs=100,\n", - " l2_weights=np.zeros((1,)),\n", - " )\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "collapsed": true, - "id": "wt_xWiSHkvX3" - }, - "outputs": [], - "source": [ - "# @title Get Bayesian detector scores for the generated outputs.\n", - "\n", - "# Watermarked responses tend to have higher Bayesian scores than unwatermarked\n", - "# responses. 
To classify responses you can set a score threshold, but this will\n", - "# depend on the distribution of scores for your use-case and your desired false\n", - "# positive / false negative rates. See the paper for full details.\n", - "\n", - "wm_bayesian_scores = bayesian_detector.score(\n", - " wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()\n", - ")\n", - "uwm_bayesian_scores = bayesian_detector.score(\n", - " uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()\n", - ")\n", - "\n", - "print('Bayesian scores for watermarked responses: ', wm_bayesian_scores)\n", - "print('Bayesian scores for unwatermarked responses: ', uwm_bayesian_scores)" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Cthb8O3LCPM1" + }, + "source": [ + "# SynthID Text: Watermarking for Generated Text\n", + "\n", + "This notebook demonstrates how to use the [SynthID Text library][synthid-code]\n", + "to apply and detect watermarks on generated text. It is divided into three major\n", + "sections and intended to be run end-to-end.\n", + "\n", + "1. **_Setup_**: Importing the SynthID Text library, choosing your model (either\n", + " [Gemma][gemma] or [GPT-2][gpt2]) and device (either CPU or GPU, depending\n", + " on your runtime), defining the watermarking configuration, and initializing\n", + " some helper functions.\n", + "1. **_Applying a watermark_**: Loading your selected model using the\n", + " [Hugging Face Transformers][transformers] library, using that model to\n", + " generate some watermarked text, and comparing the perplexity of the\n", + " watermarked text to that of text generated by the base model.\n", + "1. **_Detecting a watermark_**: Training a detector to recognize text generated\n", + " with a specific watermarking configuration, and then using that detector to\n", + " predict whether a set of examples were generated with that configuration.\n", + "\n", + "As the reference implementation for the\n", + "[SynthID Text paper in _Nature_][synthid-paper], this library and notebook are\n", + "intended for research review and reproduction only. They should not be used in\n", + "production systems. For a production-grade implementation, check out the\n", + "official SynthID logits processor in [Hugging Face Transformers][transformers].\n", + "\n", + "[gemma]: https://ai.google.dev/gemma/docs/model_card\n", + "[gpt2]: https://huggingface.co/openai-community/gpt2\n", + "[synthid-code]: https://github.com/google-deepmind/synthid-text\n", + "[synthid-paper]: https://www.nature.com/\n", + "[transformers]: https://huggingface.co/docs/transformers/en/index" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "markdown", + "metadata": { + "id": "be-I0MNRbyWT" + }, + "source": [ + "# 1. Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "aq7hChW8njFo" + }, + "outputs": [], + "source": [ + "# @title Install and import the required Python packages\n", + "#\n", + "# @markdown Running this cell may require you to restart your session.\n", + "\n", + "! 
pip install synthid-text[notebook]\n", + "\n", + "from collections.abc import Sequence\n", + "import enum\n", + "import gc\n", + "\n", + "import datasets\n", + "import huggingface_hub\n", + "from synthid_text import detector_mean\n", + "from synthid_text import logits_processing\n", + "from synthid_text import synthid_mixin\n", + "from synthid_text import detector_bayesian\n", + "import tensorflow as tf\n", + "import torch\n", + "import tqdm\n", + "import transformers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "w9a5nANolFS_" + }, + "outputs": [], + "source": [ + "# @title Choose your model.\n", + "#\n", + "# @markdown This reference implementation is configured to use the Gemma v1.0\n", + "# @markdown Instruction-Tuned variants in 2B or 7B sizes, or GPT-2.\n", + "\n", + "\n", + "class ModelName(enum.Enum):\n", + " GPT2 = 'gpt2'\n", + " GEMMA_2B = 'google/gemma-2b-it'\n", + " GEMMA_7B = 'google/gemma-7b-it'\n", + "\n", + "\n", + "model_name = 'google/gemma-7b-it' # @param ['gpt2', 'google/gemma-2b-it', 'google/gemma-7b-it']\n", + "MODEL_NAME = ModelName(model_name)\n", + "\n", + "if MODEL_NAME is not ModelName.GPT2:\n", + " huggingface_hub.notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "B_pe-hG6SW6H" + }, + "outputs": [], + "source": [ + "# @title Configure your device\n", + "#\n", + "# @markdown This notebook loads models from Hugging Face Transformers into the\n", + "# @markdown PyTorch deep learning runtime. PyTorch supports generation on CPU or\n", + "# @markdown GPU, but your chosen model will run best on the following hardware,\n", + "# @markdown some of which may require a\n", + "# @markdown [Colab Subscription](https://colab.research.google.com/signup).\n", + "# @markdown\n", + "# @markdown * Gemma v1.0 2B IT: Use a GPU with 16GB of memory, such as a T4.\n", + "# @markdown * Gemma v1.0 7B IT: Use a GPU with 32GB of memory, such as an A100.\n", + "# @markdown * GPT-2: Any runtime will work, though a High-RAM CPU or any GPU\n", + "# @markdown will be faster.\n", + "\n", + "DEVICE = (\n", + " torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n", + ")\n", + "DEVICE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "UOGvCjyVjjQ5" + }, + "outputs": [], + "source": [ + "# @title Example watermarking config\n", + "#\n", + "# @markdown SynthID Text produces unique watermarks given a configuration, with\n", + "# @markdown the most important piece of a configuration being the `keys`: a\n", + "# @markdown sequence of unique integers.\n", + "# @markdown\n", + "# @markdown This reference implementation uses a fixed watermarking\n", + "# @markdown configuration, which will be displayed when you run this cell.\n", + "\n", + "CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG\n", + "CONFIG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "79mekKj5UUZR" + }, + "outputs": [], + "source": [ + "# @title Initialize the required constants, tokenizer, and logits processor\n", + "\n", + "BATCH_SIZE = 8\n", + "NUM_BATCHES = 320\n", + "OUTPUTS_LEN = 1024\n", + "TEMPERATURE = 0.5\n", + "TOP_K = 40\n", + "TOP_P = 0.99\n", + "\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "tokenizer.padding_side = \"left\"\n", + "\n", + 
"logits_processor = logits_processing.SynthIDLogitsProcessor(\n", + " **CONFIG, top_k=TOP_K, temperature=TEMPERATURE\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "hndT3YCQUt6D" + }, + "outputs": [], + "source": [ + "# @title Utility functions to load models, compute perplexity, and process prompts.\n", + "\n", + "\n", + "def load_model(\n", + " model_name: ModelName,\n", + " expected_device: torch.device,\n", + " enable_watermarking: bool = False,\n", + ") -> transformers.PreTrainedModel:\n", + " if model_name == ModelName.GPT2:\n", + " model_cls = (\n", + " synthid_mixin.SynthIDGPT2LMHeadModel\n", + " if enable_watermarking\n", + " else transformers.GPT2LMHeadModel\n", + " )\n", + " model = model_cls.from_pretrained(model_name.value, device_map='auto')\n", + " else:\n", + " model_cls = (\n", + " synthid_mixin.SynthIDGemmaForCausalLM\n", + " if enable_watermarking\n", + " else transformers.GemmaForCausalLM\n", + " )\n", + " model = model_cls.from_pretrained(\n", + " model_name.value,\n", + " device_map='auto',\n", + " torch_dtype=torch.bfloat16,\n", + " )\n", + "\n", + " if str(model.device) != str(expected_device):\n", + " raise ValueError('Model device not as expected.')\n", + "\n", + " return model\n", + "\n", + "\n", + "def _compute_perplexity(\n", + " outputs: torch.LongTensor,\n", + " scores: torch.FloatTensor,\n", + " eos_token_mask: torch.LongTensor,\n", + " watermarked: bool = False,\n", + ") -> float:\n", + " \"\"\"Compute perplexity given the model outputs and the logits.\"\"\"\n", + " len_offset = len(scores)\n", + " if watermarked:\n", + " nll_scores = scores\n", + " else:\n", + " nll_scores = [\n", + " torch.gather(\n", + " -torch.log(torch.nn.Softmax(dim=1)(sc)),\n", + " 1,\n", + " outputs[:, -len_offset + idx, None],\n", + " )\n", + " for idx, sc in enumerate(scores)\n", + " ]\n", + " nll_sum = torch.nan_to_num(\n", + " torch.squeeze(torch.stack(nll_scores, dim=1), dim=2)\n", + " * eos_token_mask.long(),\n", + " posinf=0,\n", + " )\n", + " nll_sum = nll_sum.sum(dim=1)\n", + " nll_mean = nll_sum / eos_token_mask.sum(dim=1)\n", + " return nll_mean.sum(dim=0)\n", + "\n", + "\n", + "def _process_raw_prompt(prompt: Sequence[str]) -> str:\n", + " \"\"\"Add chat template to the raw prompt.\"\"\"\n", + " if MODEL_NAME == ModelName.GPT2:\n", + " return prompt.decode().strip('\"')\n", + " else:\n", + " return tokenizer.apply_chat_template(\n", + " [{'role': 'user', 'content': prompt.decode().strip('\"')}],\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qs9Ih8r4Dyu5" + }, + "source": [ + "# 2. 
Applying a watermark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "JJ28Aajwu9uD" + }, + "outputs": [], + "source": [ + "# @title Generate watermarked output\n", + "\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n", + "\n", + "batch_size = 1\n", + "example_inputs = [\n", + " 'I enjoy walking with my cute dog',\n", + " 'I am from New York',\n", + " 'The test was not so very hard after all',\n", + " \"I don't think they can score twice in so short a time\",\n", + "]\n", + "example_inputs = example_inputs * (int(batch_size / 4) + 1)\n", + "example_inputs = example_inputs[:batch_size]\n", + "\n", + "inputs = tokenizer(\n", + " example_inputs,\n", + " return_tensors='pt',\n", + " padding=True,\n", + ").to(DEVICE)\n", + "\n", + "model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)\n", + "torch.manual_seed(0)\n", + "outputs = model.generate(\n", + " **inputs,\n", + " do_sample=True,\n", + " temperature=0.7,\n", + " max_length=1024,\n", + " top_k=40,\n", + ")\n", + "\n", + "print('Output:\\n' + 100 * '-')\n", + "for i, output in enumerate(outputs):\n", + " print(tokenizer.decode(output, skip_special_tokens=True))\n", + " print(100 * '-')\n", + "\n", + "del inputs, outputs, model\n", + "gc.collect()\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z6VJm-ZjJ3Q8" + }, + "source": [ + "## [Optional] Compare perplexity between watermarked and non-watermarked text\n", + "\n", + "Sample [eli5 dataset](https://facebookresearch.github.io/ELI5/) outputs from\n", + "watermarked and non-watermarked models and verify that:\n", + "\n", + "* The [perplexity](https://huggingface.co/docs/transformers/en/perplexity) of\n", + " watermarked and non-watermarked text is similar.\n", + "\n", + "$$\\text{PPL}(X) = \\exp \\left\\{ {-\\frac{1}{t}\\sum_i^t \\log p_\\theta (x_i|x_{= padded_length:\n", + " line = line[:padded_length]\n", + " else:\n", + " line = line + [\n", + " tokenizer.eos_token_id for _ in range(padded_length - len(line))\n", + " ]\n", + " batched.append(torch.tensor(line, dtype=torch.long, device=DEVICE)[None, :])\n", + " if len(batched) == NEG_BATCH_SIZE:\n", + " tokenized_uwm_outputs.append(torch.cat(batched, dim=0))\n", + " batched = []\n", + " if i > NUM_NEGATIVES:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "UA6iSRKmklTM" + }, + "outputs": [], + "source": [ + "# @title Train the Bayesian detector\n", + "bayesian_detector, test_loss = (\n", + " detector_bayesian.BayesianDetector.train_best_detector(\n", + " tokenized_wm_outputs=wm_outputs,\n", + " tokenized_uwm_outputs=tokenized_uwm_outputs,\n", + " logits_processor=logits_processor,\n", + " tokenizer=tokenizer,\n", + " torch_device=DEVICE,\n", + " max_padded_length=MAX_PADDED_LENGTH,\n", + " pos_truncation_length=POS_TRUNCATION_LENGTH,\n", + " neg_truncation_length=NEG_TRUNCATION_LENGTH,\n", + " verbose=True,\n", + " learning_rate=3e-3,\n", + " n_epochs=100,\n", + " l2_weights=np.zeros((1,)),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "collapsed": true, + "id": "wt_xWiSHkvX3" + }, + "outputs": [], + "source": [ + "# @title Get Bayesian detector scores for the generated outputs.\n", + "\n", + "# Watermarked responses tend to have higher Bayesian scores than unwatermarked\n", + "# responses. 
To classify responses you can set a score threshold, but this will\n", + "# depend on the distribution of scores for your use-case and your desired false\n", + "# positive / false negative rates. See the paper for full details.\n", + "\n", + "wm_bayesian_scores = bayesian_detector.score(\n", + " wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()\n", + ")\n", + "uwm_bayesian_scores = bayesian_detector.score(\n", + " uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()\n", + ")\n", + "\n", + "print('Bayesian scores for watermarked responses: ', wm_bayesian_scores)\n", + "print('Bayesian scores for unwatermarked responses: ', uwm_bayesian_scores)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/src/synthid_text/detector_bayesian.py b/src/synthid_text/detector_bayesian.py index 73411b5..e87e239 100644 --- a/src/synthid_text/detector_bayesian.py +++ b/src/synthid_text/detector_bayesian.py @@ -20,7 +20,7 @@ import enum import functools import gc -from typing import Any +from typing import Any, Optional, Union import flax.linen as nn import jax import jax.numpy as jnp @@ -64,7 +64,7 @@ def pad_to_len( def filter_and_truncate( outputs: torch.tensor, - truncation_length: int | None, + truncation_length: Optional[int], eos_token_mask: torch.tensor, ) -> torch.tensor: """Filter and truncate outputs to given length. @@ -90,8 +90,8 @@ def process_outputs_for_training( logits_processor: logits_processing.SynthIDLogitsProcessor, tokenizer: Any, *, - pos_truncation_length: int | None, - neg_truncation_length: int | None, + pos_truncation_length: Optional[int], + neg_truncation_length: Optional[int], max_length: int, is_cv: bool, is_pos: bool, @@ -233,7 +233,7 @@ class LikelihoodModelWatermarked(nn.Module, LikelihoodModel): """ watermarking_depth: int - params: Mapping[str, Mapping[str, Any]] | None = None + params: Optional[Mapping[str, Mapping[str, Any]]] = None def setup(self): """Initializes the model parameters.""" @@ -394,7 +394,7 @@ class BayesianDetectorModule(nn.Module): """ watermarking_depth: int # The number of tournament layers. - params: Mapping[str, Mapping[str, Any]] | None = None + params: Optional[Mapping[str, Mapping[str, Any]]] = None baserate: float = 0.5 # Prior probability P(w) that a text is watermarked. 
@property @@ -442,7 +442,7 @@ def __call__( def score( self, - g_values: jnp.ndarray | Sequence[jnp.ndarray], + g_values: Union[jnp.ndarray, Sequence[jnp.ndarray]], mask: jnp.ndarray, ) -> jnp.ndarray: if self.params is None: @@ -522,9 +522,9 @@ def train( seed: int = 0, l2_weight: float = 0.0, shuffle: bool = True, - g_values_val: jnp.ndarray | None = None, - mask_val: jnp.ndarray | None = None, - watermarked_val: jnp.ndarray | None = None, + g_values_val: Optional[jnp.ndarray] = None, + mask_val: Optional[jnp.ndarray] = None, + watermarked_val: Optional[jnp.ndarray] = None, verbose: bool = False, validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR, ) -> tuple[Mapping[int, Mapping[str, PyTree]], float]: @@ -761,14 +761,14 @@ def score(self, outputs: jnp.ndarray) -> jnp.ndarray: def process_raw_model_outputs( cls, *, - tokenized_wm_outputs: Sequence[np.ndarray] | np.ndarray, - tokenized_uwm_outputs: Sequence[np.ndarray] | np.ndarray, + tokenized_wm_outputs: Union[Sequence[np.ndarray], np.ndarray], + tokenized_uwm_outputs: Union[Sequence[np.ndarray], np.ndarray], logits_processor: logits_processing.SynthIDLogitsProcessor, tokenizer: Any, torch_device: torch.device, test_size: float = 0.3, - pos_truncation_length: int | None = 200, - neg_truncation_length: int | None = 100, + pos_truncation_length: Optional[int] = 200, + neg_truncation_length: Optional[int] = 100, max_padded_length: int = 2300, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Process raw models outputs into inputs we can train. @@ -986,14 +986,14 @@ def train_best_detector_given_g_values( def train_best_detector( cls, *, - tokenized_wm_outputs: Sequence[np.ndarray] | np.ndarray, - tokenized_uwm_outputs: Sequence[np.ndarray] | np.ndarray, + tokenized_wm_outputs: Union[Sequence[np.ndarray], np.ndarray], + tokenized_uwm_outputs: Union[Sequence[np.ndarray], np.ndarray], logits_processor: logits_processing.SynthIDLogitsProcessor, tokenizer: Any, torch_device: torch.device, test_size: float = 0.3, - pos_truncation_length: int | None = 200, - neg_truncation_length: int | None = 100, + pos_truncation_length: Optional[int] = 200, + neg_truncation_length: Optional[int] = 100, max_padded_length: int = 2300, n_epochs: int = 50, learning_rate: float = 2.1e-2, diff --git a/src/synthid_text/detector_mean.py b/src/synthid_text/detector_mean.py index a961797..e022d9e 100644 --- a/src/synthid_text/detector_mean.py +++ b/src/synthid_text/detector_mean.py @@ -15,6 +15,7 @@ """Code for Mean and Weighted Mean scoring functions.""" +from typing import Optional import jax.numpy as jnp @@ -43,7 +44,7 @@ def mean_score( def weighted_mean_score( g_values: jnp.ndarray, mask: jnp.ndarray, - weights: jnp.ndarray | None = None, + weights: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Computes the Weighted Mean score. 
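The `Optional[jnp.ndarray]` rewrite above is the same mechanical transformation applied across all of these files: PEP 604 union syntax such as `jnp.ndarray | None` is only legal in evaluated annotations from Python 3.10 onward, so each occurrence is respelled with `typing.Optional` or `typing.Union`. A minimal sketch of the pattern, mirroring the `weighted_mean_score` signature above (the `scaled_mean` helper is illustrative, not part of the library):

```python
# Python 3.9-safe annotations: `jnp.ndarray | None` raises a TypeError at
# import time on 3.9, so typing.Optional/Union are used instead.
from collections.abc import Sequence
from typing import Optional, Union

import jax.numpy as jnp


def scaled_mean(
    g_values: Union[jnp.ndarray, Sequence[float]],  # was: jnp.ndarray | Sequence[float]
    weights: Optional[jnp.ndarray] = None,  # was: jnp.ndarray | None = None
) -> jnp.ndarray:
  """Illustrative helper showing the Optional/Union respelling."""
  g_values = jnp.asarray(g_values)
  if weights is None:
    # With no weights given, fall back to uniform weights over the last axis.
    weights = jnp.ones(g_values.shape[-1])
  return (g_values * weights).mean(axis=-1)
```

For example, `scaled_mean([0.4, 0.6, 0.5])` averages the values with uniform weights, while an explicit `weights` array rescales the last axis before averaging.
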
diff --git a/src/synthid_text/synthid_mixin.py b/src/synthid_text/synthid_mixin.py index d92752c..5c723dd 100644 --- a/src/synthid_text/synthid_mixin.py +++ b/src/synthid_text/synthid_mixin.py @@ -16,7 +16,7 @@ """SynthID watermarked mixin class.""" from collections.abc import Mapping -from typing import Any +from typing import Any, Optional, Union import immutabledict import torch @@ -136,10 +136,12 @@ def _sample( stopping_criteria: transformers.StoppingCriteriaList, generation_config: transformers.GenerationConfig, synced_gpus: bool, - streamer: "transformers.BaseStreamer | None", - logits_warper: transformers.LogitsProcessorList | None = None, + streamer: Optional["transformers.BaseStreamer"], + logits_warper: Optional[transformers.LogitsProcessorList] = None, **model_kwargs, - ) -> transformers.generation.utils.GenerateNonBeamOutput | torch.LongTensor: + ) -> Union[ + transformers.generation.utils.GenerateNonBeamOutput, torch.LongTensor + ]: r"""Sample sequence of tokens. Generates sequences of token ids for models with a language modeling head diff --git a/src/synthid_text/synthid_mixin_test.py b/src/synthid_text/synthid_mixin_test.py index e9f0050..0aae3a5 100644 --- a/src/synthid_text/synthid_mixin_test.py +++ b/src/synthid_text/synthid_mixin_test.py @@ -25,7 +25,7 @@ from . import logits_processing -@dataclasses.dataclass(frozen=True, kw_only=True) +@dataclasses.dataclass(frozen=True) class Config: is_encoder_decoder: bool = True From 13dc0ad35adccc5e30ed2c4f56e1a08efd3e5698 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 11 Nov 2024 15:38:01 +0000 Subject: [PATCH 2/3] Use concrete imports. Rename test helper func to avoid pickup by pytest. --- src/synthid_text/detector_bayesian.py | 2 +- src/synthid_text/logits_processing.py | 2 +- src/synthid_text/logits_processing_test.py | 12 ++++++------ src/synthid_text/synthid_mixin.py | 2 +- src/synthid_text/synthid_mixin_test.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/synthid_text/detector_bayesian.py b/src/synthid_text/detector_bayesian.py index e87e239..955424b 100644 --- a/src/synthid_text/detector_bayesian.py +++ b/src/synthid_text/detector_bayesian.py @@ -31,7 +31,7 @@ import torch import tqdm -from . import logits_processing +from synthid_text import logits_processing def pad_to_len( diff --git a/src/synthid_text/logits_processing.py b/src/synthid_text/logits_processing.py index 8a10775..dea3da8 100644 --- a/src/synthid_text/logits_processing.py +++ b/src/synthid_text/logits_processing.py @@ -20,7 +20,7 @@ import torch import transformers -from . import hashing_function +from synthid_text import hashing_function def update_scores( diff --git a/src/synthid_text/logits_processing_test.py b/src/synthid_text/logits_processing_test.py index 881ac68..33a59d4 100644 --- a/src/synthid_text/logits_processing_test.py +++ b/src/synthid_text/logits_processing_test.py @@ -22,12 +22,12 @@ import torch import tqdm -from . import logits_processing -from . import g_value_expectations -from . 
import torch_testing +from synthid_text import logits_processing +from synthid_text import g_value_expectations +from synthid_text import torch_testing -def test_mean_g_value_matches_theoretical( +def does_mean_g_value_matches_theoretical( vocab_size: int, ngram_len: int, batch_size: int, @@ -307,7 +307,7 @@ def test_bias_from_logits_processor( ): """Check if watermarked distribution converges to input distribution.""" device = torch_testing.torch_device() - result = test_mean_g_value_matches_theoretical( + mean, expected, passes = does_mean_g_value_matches_theoretical( vocab_size=vocab_size, ngram_len=ngram_len, batch_size=20_000, @@ -316,7 +316,7 @@ def test_bias_from_logits_processor( device=device, num_leaves=num_leaves, ) - self.assertTrue(result[2]) + self.assertTrue(passes) class LogitsProcessorTest(absltest.TestCase): diff --git a/src/synthid_text/synthid_mixin.py b/src/synthid_text/synthid_mixin.py index 5c723dd..eb0bf14 100644 --- a/src/synthid_text/synthid_mixin.py +++ b/src/synthid_text/synthid_mixin.py @@ -22,7 +22,7 @@ import torch import transformers -from . import logits_processing +from synthid_text import logits_processing DEFAULT_WATERMARKING_CONFIG = immutabledict.immutabledict({ diff --git a/src/synthid_text/synthid_mixin_test.py b/src/synthid_text/synthid_mixin_test.py index 0aae3a5..8626553 100644 --- a/src/synthid_text/synthid_mixin_test.py +++ b/src/synthid_text/synthid_mixin_test.py @@ -21,8 +21,8 @@ import transformers from transformers import utils as transformers_utils -from . import synthid_mixin -from . import logits_processing +from synthid_text import synthid_mixin +from synthid_text import logits_processing @dataclasses.dataclass(frozen=True) From dd021e8169afab2ec26af8cbb338062f301b8a36 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Mon, 11 Nov 2024 15:50:19 +0000 Subject: [PATCH 3/3] Fixes #10. Runs PyTest for PRs with GitHub Actions --- .github/workflows/ci.yaml | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/ci.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..acd5fc3 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,44 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +name: CI + +on: + push: + branches: [main, dev] + pull_request: + branches: [main, dev] + +jobs: + build-and-test: + name: Build and test (Python ${{ matrix.python-version }}) + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + defaults: + run: + shell: bash -l {0} + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install SynthID Text package with testing dependencies + run: python -m pip install -e '.[test]' + - name: Test SynthID Text + run: pytest -v
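Taken together, the three patches remove every Python 3.10-only construct from the library, which is what allows the CI matrix above to pass on 3.9 as well as 3.10 and 3.11. A compact sketch of the constructs involved and the 3.9-compatible replacements the patches use (the names below are illustrative, not taken from the library):

```python
# Illustrative summary of the Python 3.10-only features removed by this
# series, paired with the 3.9-compatible forms used instead.
import dataclasses
from typing import Optional


# 1. dataclasses.dataclass(kw_only=True) was added in Python 3.10; the test
#    config in synthid_mixin_test.py drops the flag to stay importable on 3.9.
@dataclasses.dataclass(frozen=True)
class ExampleConfig:
  is_encoder_decoder: bool = True


# 2. Structural pattern matching (match/case) is 3.10+ syntax; the notebook's
#    load_model() and _process_raw_prompt() now use plain if/else chains.
def pick_model_family(name: str) -> str:
  if name == 'gpt2':
    return 'gpt2'
  return 'gemma'


# 3. PEP 604 unions in evaluated annotations (`int | None`) raise a TypeError
#    at import time on 3.9; typing.Optional and typing.Union are the backports
#    used throughout detector_bayesian.py, detector_mean.py, and
#    synthid_mixin.py.
def truncate_length(limit: Optional[int] = None) -> int:
  return 0 if limit is None else limit
```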