diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py
new file mode 100644
index 000000000..7c79cd5d4
--- /dev/null
+++ b/pyrit/prompt_converter/insert_punctuation_converter.py
@@ -0,0 +1,165 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import random
+import string
+import re
+from typing import List, Optional
+
+from pyrit.models import PromptDataType
+from pyrit.prompt_converter import PromptConverter, ConverterResult
+
+
+class InsertPunctuationConverter(PromptConverter):
+    """
+    Inserts punctuation into a prompt to test robustness.
+    Punctuation insertion: inserting single punctuations in string.punctuation.
+    Words in a prompt: a word does not contain any punctuation and space.
+    "a1b2c3" is a word; "a1 2" are 2 words; "a1,b,3" are 3 words.
+    """
+
+    default_punctuation_list = [",", ".", "!", "?", ":", ";", "-"]
+
+    def __init__(self, word_swap_ratio: float = 0.2, between_words: bool = True) -> None:
+        """
+        Initialize the converter with an optional word swap ratio and insertion mode.
+        Args:
+            word_swap_ratio (float): Percentage of words to perturb. Defaults to 0.2.
+            between_words (bool): If True, insert punctuation only between words.
+                If False, insert punctuation within words. Defaults to True.
+        Raises:
+            ValueError: If word_swap_ratio is not in the interval (0, 1].
+        """
+        # Swap ratio cannot be 0 or larger than 1
+        if not 0 < word_swap_ratio <= 1:
+            raise ValueError("word_swap_ratio must be between 0 to 1, as (0, 1].")
+
+        self._word_swap_ratio = word_swap_ratio
+        self._between_words = between_words
+
+    def _is_valid_punctuation(self, punctuation_list: List[str]) -> bool:
+        """
+        Check if all items in the list are valid punctuation characters in string.punctuation.
+        Space, letters, numbers, double punctuations and an empty list are all invalid.
+        Args:
+            punctuation_list (List[str]): List of punctuations to validate.
+        Returns:
+            bool: True if the list is non-empty and every item is exactly one character of string.punctuation.
+        """
+        # An empty list would make random.choice fail later, so reject it here
+        if not punctuation_list:
+            return False
+        # "in" on a str is a substring test: "" and runs such as "()" are contained in
+        # string.punctuation, so the length must be checked explicitly
+        return all(len(item) == 1 and item in string.punctuation for item in punctuation_list)
+
+    async def convert_async(
+        self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[List[str]] = None
+    ) -> ConverterResult:
+        """
+        Convert the given prompt by inserting punctuation.
+        Args:
+            prompt (str): The text to convert.
+            input_type (PromptDataType): The type of input data.
+            punctuation_list (Optional[List[str]]): List of punctuations to use for insertion.
+                If None, default_punctuation_list is used.
+        Returns:
+            ConverterResult: A ConverterResult containing the modified prompt.
+        Raises:
+            ValueError: If input_type is not supported or punctuation_list fails validation.
+        """
+        if not self.input_supported(input_type):
+            raise ValueError("Input type not supported")
+
+        # Initialize default punctuation list
+        # If not specified, defaults to default_punctuation_list
+        if punctuation_list is None:
+            punctuation_list = self.default_punctuation_list
+        elif not self._is_valid_punctuation(punctuation_list):
+            raise ValueError(
+                f"Invalid punctuations: {punctuation_list}."
+                f" Only single characters from {string.punctuation} are allowed."
+            )
+
+        modified_prompt = self._insert_punctuation(prompt, punctuation_list)
+        return ConverterResult(output_text=modified_prompt, output_type="text")
+
+    def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str:
+        """
+        Insert punctuation into the prompt.
+        Args:
+            prompt (str): The text to modify.
+            punctuation_list (List[str]): List of punctuations for insertion.
+        Returns:
+            str: The modified prompt with inserted punctuation from helper method.
+        """
+        # Words list contains single spaces, single word without punctuations, single punctuations
+        words = re.findall(r"\w+|[^\w\s]|\s", prompt)
+        # Maintains indices for actual "words", i.e. letters and numbers not divided by punctuations
+        word_indices = [i for i, token in enumerate(words) if not re.match(r"\W", token)]
+
+        # If there's no actual word without punctuations in the list, insert random punctuation at position 0
+        if not word_indices:
+            return random.choice(punctuation_list) + prompt
+
+        # Calculate the number of insertions, ensuring at least one punctuation is inserted
+        num_insertions = max(1, round(len(word_indices) * self._word_swap_ratio))
+
+        if self._between_words:
+            return self._insert_between_words(words, word_indices, num_insertions, punctuation_list)
+        return self._insert_within_words(prompt, num_insertions, punctuation_list)
+
+    def _insert_between_words(
+        self, words: List[str], word_indices: List[int], num_insertions: int, punctuation_list: List[str]
+    ) -> str:
+        """
+        Insert punctuation between words in the prompt.
+        Args:
+            words (List[str]): List of words and punctuations.
+            word_indices (List[int]): Indices of the actual words without punctuations in words list.
+            num_insertions (int): Number of punctuations to insert.
+            punctuation_list (List[str]): punctuations for insertion.
+
+        Returns:
+            str: The modified prompt with inserted punctuation.
+        """
+        # Randomly choose num_insertions indices from actual word indices.
+        insert_indices = random.sample(word_indices, num_insertions)
+        INSERT_BEFORE = 0
+        INSERT_AFTER = 1
+        for index in insert_indices:
+            if random.randint(INSERT_BEFORE, INSERT_AFTER) == INSERT_AFTER:
+                words[index] += random.choice(punctuation_list)
+            else:
+                words[index] = random.choice(punctuation_list) + words[index]
+        # Join the words list and return a modified prompt
+        return "".join(words).strip()
+
+    def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_list: List[str]) -> str:
+        """
+        Insert punctuation at any indices in the prompt, can insert into a word.
+        Args:
+            prompt (str): The prompt string.
+            num_insertions (int): Number of punctuations to insert.
+            punctuation_list (List[str]): punctuations for insertion.
+        Returns:
+            str: The modified prompt with inserted punctuation.
+        """
+        # List of chars in the prompt string
+        prompt_list = list(prompt)
+        if len(prompt_list) <= num_insertions:
+            # Too few characters to sample distinct positions: insert once, right after the first char
+            insert_indices = [1]
+        else:
+            insert_indices = random.sample(range(0, len(prompt_list) - 1), num_insertions)
+
+        # Insert from the highest index down so an earlier insertion never shifts
+        # the position that a remaining sampled index refers to
+        for index in sorted(insert_indices, reverse=True):
+            prompt_list.insert(index, random.choice(punctuation_list))
+
+        return "".join(prompt_list).strip()
+
+    def input_supported(self, input_type: PromptDataType) -> bool:
+        """Return True only for the "text" input type."""
+        return input_type == "text"
diff --git a/tests/converter/test_insert_punctuation_converter.py b/tests/converter/test_insert_punctuation_converter.py
new file mode 100644
index 000000000..bd62ed3ff
--- /dev/null
+++ b/tests/converter/test_insert_punctuation_converter.py
@@ -0,0 +1,84 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import re
+import pytest
+from pyrit.prompt_converter.insert_punctuation_converter import InsertPunctuationConverter
+
+
+# Test for correctness
+# Long prompt, short prompt, weird spacing and punctuation, non-wordy prompt, short and empty prompt
+@pytest.mark.parametrize(
+    "input_prompt,between_words,punctuation_list,word_swap_ratio,expected_punctuation_count",
+    [
+        ("Despite the rain, we decided to go hiking; it was a refreshing experience.", True, [",", "!", "]"], 1, 16),
+        ("Quick!", False, [",", "~", "]"], 1, 2),
+        ("   Hello,   world!   ", True, [",", "[", "]"], 0.3, 3),
+        ("....", True, [",", "[", ">"], 0.2, 5),
+        ("Numbers are also words, 1234 not intuitive, not symbols $@.", True, [",", "[", "]"], 0.6, 10),
+        ("", True, [",", "$", "]"], 0.9, 1),
+        ("a b", False, [",", "^", "]"], 1, 2),
+        ("I can't wait!!!", False, [",", "/", "]"], 0.4, 6),
+    ],
+)
+@pytest.mark.asyncio
+async def test_word_swap_ratio(
+    input_prompt, between_words, punctuation_list, word_swap_ratio, expected_punctuation_count
+):
+    converter = InsertPunctuationConverter(word_swap_ratio=word_swap_ratio, between_words=between_words)
+    result = await converter.convert_async(prompt=input_prompt, punctuation_list=punctuation_list)
+    modified_prompt = result.output_text
+    assert (
+        punctuation_count := len(re.findall(r"[^\w\s]", modified_prompt))
+    ) == expected_punctuation_count, (
+        f"Expect {expected_punctuation_count} punctuations found in prompt: {punctuation_count}"
+    )
+
+
+# Test default swap ratio = 0.2
+@pytest.mark.parametrize(
+    "input_prompt, expected_punctuation_count",
+    [("count 1 2 3 4 5 6 7 8 9 and 10.", 3), ("Aha!", 2)],
+)
+@pytest.mark.asyncio
+async def test_default_swap(input_prompt, expected_punctuation_count):
+    converter = InsertPunctuationConverter()
+    result = await converter.convert_async(prompt=input_prompt)
+    modified_prompt = result.output_text
+    assert (
+        punctuation_count := len(re.findall(r"[^\w\s]", modified_prompt))
+    ) == expected_punctuation_count, (
+        f"Expect {expected_punctuation_count} punctuations found in prompt: {punctuation_count}"
+    )
+
+
+# Test value error raising for invalid swap ratio
+@pytest.mark.parametrize(
+    "word_swap_ratio",
+    [-0.1, 0, 1.5],
+)
+@pytest.mark.asyncio
+async def test_invalid_word_swap_ratio(word_swap_ratio):
+    with pytest.raises(ValueError):
+        InsertPunctuationConverter(word_swap_ratio=word_swap_ratio)
+
+
+# Test value error raising for invalid punctuations
+# Multi-character strings, whitespace, alphanumerics and the empty string must all be rejected
+@pytest.mark.parametrize(
+    "punctuation_list",
+    ["~~", " ", "1", "a", "//", "", "()"],
+)
+@pytest.mark.asyncio
+async def test_invalid_punctuation_list(punctuation_list):
+    with pytest.raises(ValueError):
+        converter = InsertPunctuationConverter()
+        await converter.convert_async(prompt="prompt", punctuation_list=[punctuation_list])
+
+
+# Test value error raising for an empty punctuation list
+@pytest.mark.asyncio
+async def test_empty_punctuation_list():
+    with pytest.raises(ValueError):
+        converter = InsertPunctuationConverter()
+        await converter.convert_async(prompt="prompt", punctuation_list=[])