Make GitTokenizer call args similar to Blip
st81 committed Mar 10, 2024
1 parent 5c44259 commit 868bc4d
Showing 2 changed files with 52 additions and 75 deletions.
65 changes: 26 additions & 39 deletions src/transformers/models/git/tokenization_git.py
@@ -13,15 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for GIT."""
-from typing import List
+from typing import List, Optional, Union
 
-from ...tokenization_utils_base import BatchEncoding, TruncationStrategy
-from ...utils import (
-    PaddingStrategy,
-    TensorType,
-    is_torch_available,
-    logging,
-)
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType, is_torch_available, logging
 from ..bert.tokenization_bert import BertTokenizer

@@ -65,48 +60,40 @@ class GitTokenizer(BertTokenizer):
 
     def __call__(
         self,
-        text: str | List[str] | List[List[str]] = None,
-        text_pair: str | List[str] | List[List[str]] | None = None,
-        text_target: str | List[str] | List[List[str]] = None,
-        text_pair_target: str | List[str] | List[List[str]] | None = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         add_special_tokens: bool = True,
-        padding: bool | str | PaddingStrategy = False,
-        truncation: bool | str | TruncationStrategy = None,
-        max_length: int | None = None,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
         stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: int | None = None,
-        return_tensors: str | TensorType | None = None,
-        return_token_type_ids: bool | None = None,
-        return_attention_mask: bool | None = None,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
         return_overflowing_tokens: bool = False,
         return_special_tokens_mask: bool = False,
         return_offsets_mapping: bool = False,
+        return_token_type_ids: bool = False,
         return_length: bool = False,
         verbose: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ) -> BatchEncoding:
         add_special_tokens = False
         encodings = super().__call__(
-            text,
-            text_pair,
-            text_target,
-            text_pair_target,
-            add_special_tokens,
-            padding,
-            truncation,
-            max_length,
-            stride,
-            is_split_into_words,
-            pad_to_multiple_of,
-            return_tensors,
-            return_token_type_ids,
-            return_attention_mask,
-            return_overflowing_tokens,
-            return_special_tokens_mask,
-            return_offsets_mapping,
-            return_length,
-            verbose,
+            text=text,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_token_type_ids=return_token_type_ids,
+            return_length=return_length,
+            verbose=verbose,
+            return_tensors=return_tensors,
+            **kwargs,
         )
         input_ids = encodings["input_ids"]
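The net effect of the diff above: __call__ now accepts a single text input in the Blip style (no more text_pair/text_target/text_pair_target), drops is_split_into_words, and forwards everything to BertTokenizer.__call__ by keyword after forcing add_special_tokens = False. A minimal usage sketch, assuming this branch's module is importable and using microsoft/git-base as an illustrative checkpoint (the checkpoint name is an assumption, not part of the commit):

    from transformers.models.git.tokenization_git import GitTokenizer

    # Any BertTokenizer-compatible vocab works here; the checkpoint is illustrative.
    tokenizer = GitTokenizer.from_pretrained("microsoft/git-base")

    # Single-text call, mirroring Blip; pair/target arguments are no longer accepted.
    encoding = tokenizer(
        "two cats sleeping on a pink couch",
        padding="max_length",
        max_length=32,
        return_tensors="pt",
    )
    # add_special_tokens is forced to False before delegating, so any special-token
    # handling happens in the method body truncated above ("input_ids = ..." onward).
    print(encoding["input_ids"])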
62 changes: 26 additions & 36 deletions src/transformers/models/git/tokenization_git_fast.py
@@ -13,12 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for GIT."""
-from typing import List
+from typing import List, Optional, Union
 
-from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
-from transformers.utils.generic import PaddingStrategy, TensorType
-
-from ...utils import is_torch_available, logging
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType, is_torch_available, logging
 from ..bert.tokenization_bert_fast import BertTokenizerFast
 from .tokenization_git import GitTokenizer

@@ -66,48 +64,40 @@ class GitTokenizerFast(BertTokenizerFast):
 
     def __call__(
         self,
-        text: str | List[str] | List[List[str]] = None,
-        text_pair: str | List[str] | List[List[str]] | None = None,
-        text_target: str | List[str] | List[List[str]] = None,
-        text_pair_target: str | List[str] | List[List[str]] | None = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         add_special_tokens: bool = True,
-        padding: bool | str | PaddingStrategy = False,
-        truncation: bool | str | TruncationStrategy = None,
-        max_length: int | None = None,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length: Optional[int] = None,
         stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: int | None = None,
-        return_tensors: str | TensorType | None = None,
-        return_token_type_ids: bool | None = None,
-        return_attention_mask: bool | None = None,
+        pad_to_multiple_of: Optional[int] = None,
+        return_attention_mask: Optional[bool] = None,
         return_overflowing_tokens: bool = False,
         return_special_tokens_mask: bool = False,
         return_offsets_mapping: bool = False,
+        return_token_type_ids: bool = False,
         return_length: bool = False,
         verbose: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ) -> BatchEncoding:
         add_special_tokens = False
         encodings = super().__call__(
-            text,
-            text_pair,
-            text_target,
-            text_pair_target,
-            add_special_tokens,
-            padding,
-            truncation,
-            max_length,
-            stride,
-            is_split_into_words,
-            pad_to_multiple_of,
-            return_tensors,
-            return_token_type_ids,
-            return_attention_mask,
-            return_overflowing_tokens,
-            return_special_tokens_mask,
-            return_offsets_mapping,
-            return_length,
-            verbose,
+            text=text,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_token_type_ids=return_token_type_ids,
+            return_length=return_length,
+            verbose=verbose,
+            return_tensors=return_tensors,
+            **kwargs,
        )
         input_ids = encodings["input_ids"]
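GitTokenizerFast receives the identical rework, keeping the slow/fast pair in lockstep. Forwarding by keyword rather than nineteen positional arguments also means the call can no longer silently misalign if the parent signature ever changes order. A parity sketch under the same assumptions as above (illustrative checkpoint, branch modules importable):

    from transformers.models.git.tokenization_git import GitTokenizer
    from transformers.models.git.tokenization_git_fast import GitTokenizerFast

    slow = GitTokenizer.from_pretrained("microsoft/git-base")      # assumed checkpoint
    fast = GitTokenizerFast.from_pretrained("microsoft/git-base")  # assumed checkpoint

    caption = "two cats sleeping on a pink couch"
    # Both classes forward the same keyword arguments to their Bert parents,
    # so the resulting ids are expected to match.
    assert slow(caption)["input_ids"] == fast(caption)["input_ids"]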
