Enable conversational pipeline for GPTSw3Tokenizer (#24648)
* feat: Add `_build_conversation_input_ids` to GPT-SW3 tokenizer, adjust line length

* feat: Merge in PR #24504.

This allows the GPT-SW3 models (and other GPT-2-based models) to be 4-bit quantised
using `load_in_4bit` with `bitsandbytes` (a usage sketch follows the commit message below).

* fix: F-string

* fix: F-string

* fix: Remove EOS token from all responses

* fix: Remove redundant newlines

* feat: Add `load_in_4bit` to `Pipeline`

* fix: Separate turns with `\n<s>\n` rather than `<s>`

* fix: Add missing newline in prompt

* tests: Add unit tests for the new `_build_conversation_input_ids` method

* style: Automatic style correction

* tests: Compare encodings rather than decodings

* fix: Remove `load_in_4bit` from pipeline arguments

* docs: Add description and references of the GPT-SW3 chat format

* style: Line breaks

* Apply suggestions from code review

Fix Conversation type hints

Co-authored-by: Arthur <[email protected]>

* fix: Import TYPE_CHECKING

* style: Run automatic fixes

* tests: Remove `_build_conversation_input_ids` unit tests

* tests: Remove import of `Conversation` in GPT-SW3 unit test

* style: Revert formatting

* style: Move TYPE_CHECKING line after all imports

* style: Imports order

* fix: Change prompt to ensure that `sp_model.encode` and `encode` yields same result

* docs: Add TODO comment related to the addition of whitespace during decoding

* style: Automatic style checks

* fix: Remove final whitespace in prompt, as prefix whitespace is used by sentencepiece

---------

Co-authored-by: Arthur <[email protected]>
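[Editor's note] The commit message above mentions 4-bit quantisation via `load_in_4bit`. A minimal sketch of that loading path, assuming `bitsandbytes`, `accelerate`, and a CUDA device are available; the checkpoint name is taken from the model card cited in the diff below and is an assumption, not part of this commit:

```
# Sketch: load a GPT-SW3 checkpoint in 4-bit precision (not from this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "AI-Sweden-Models/gpt-sw3-126m-instruct"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
```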
saattrupdan and ArthurZucker authored Jul 7, 2023
1 parent f614b6e commit abaca9f
Showing 1 changed file with 41 additions and 7 deletions.

src/transformers/models/gpt_sw3/tokenization_gpt_sw3.py
@@ -1,20 +1,23 @@
 """The tokenizer used by the GPT-SW3 models."""

 import os
 import re
 import unicodedata
-from ... import is_torch_available
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_torch_available, logging


 if is_torch_available():
     import torch

-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import sentencepiece as spm
-
-from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import logging
+if TYPE_CHECKING:
+    from transformers.pipelines.conversational import Conversation


 logger = logging.get_logger(__name__)
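[Editor's note] On the `TYPE_CHECKING` guard added in the hunk above: `Conversation` lives in the conversational pipeline module, so importing it at module load time would make the tokenizer depend on the pipelines package at runtime. The guard defers the import to static type checkers, and the quoted annotation keeps the hint usable. A minimal standalone sketch of the pattern (not part of the commit):

```
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Evaluated by type checkers (mypy, pyright) only, never at runtime.
    from transformers.pipelines.conversational import Conversation


def build_ids(conversation: "Conversation") -> List[int]:
    # The string annotation resolves for type checkers even though the
    # class is never imported when this module actually runs.
    return []
```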
@@ -230,8 +233,10 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
         for token in tokens:
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
+                # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
                 if not prev_is_special:
                     out_string += " "
+
                 out_string += self.sp_model.decode(current_sub_tokens) + token
                 prev_is_special = True
                 current_sub_tokens = []
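[Editor's note] The TODO in this hunk flags a round-trip quirk: because `convert_tokens_to_string` prepends a space before a special token, `decode(encode(doc))` can gain whitespace that `doc` never had. A small probe one might run to observe it; the checkpoint name is assumed from the model card cited below, and the exact output depends on the tokenizer:

```
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-126m-instruct")
doc = "Hej<s>då"  # inline special token with no surrounding spaces ("Hi<s>bye")
roundtrip = tok.decode(tok.encode(doc))
print(repr(doc), repr(roundtrip))  # the round trip may show an extra space before "<s>"
```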
@@ -312,3 +317,32 @@ def decode_fast(self, token_ids: Union[int, List[int]]) -> str:
         """

         return self.sp_model.decode(token_ids)
+
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+        """Builds the input ids for a conversation.
+        This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
+        The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
+        ```
+        <eos><bos>User: Jag tycker träd är fina<bos>Bot: Kul att du tycker det!<bos>...
+        ```
+
+        Args:
+            conversation (`Conversation`):
+                Conversation to build input ids for.
+
+        Returns:
+            `List[int]`:
+                Input ids for the conversation.
+
+        References:
+            - [1] https://doi.org/10.48550/arXiv.2305.12987
+            - [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
+            - [3] https://github.com/openai/openai-python/blob/main/chatml.md
+        """
+        all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
+        prompt = (
+            f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
+        )
+        return self.encode(text=prompt)
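[Editor's note] With `_build_conversation_input_ids` in place, GPT-SW3 checkpoints work with the conversational pipeline. A hedged usage sketch: the checkpoint comes from the model card referenced in the docstring, and the generated reply is illustrative:

```
from transformers import Conversation, pipeline

chatbot = pipeline("conversational", model="AI-Sweden-Models/gpt-sw3-126m-instruct")
conversation = Conversation("Jag tycker träd är fina")  # "I think trees are nice"
conversation = chatbot(conversation)
# Internally the tokenizer builds a prompt of the form
# "<eos><bos>User: Jag tycker träd är fina<bos>Bot:" (with the tokenizer's real
# eos/bos token strings) and the model completes the trailing "Bot:" turn.
print(conversation.generated_responses[-1])
```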
