[FEA] Support packing to a max input sequence length with cudf-subword tokenizer #6089

Open
VibhuJawa opened this issue Aug 25, 2020 · 1 comment
Labels
feature request New feature or request libcudf Affects libcudf (C++/CUDA) code. Python Affects Python cuDF API. strings strings issues (C++ and Python)


@VibhuJawa
Member

VibhuJawa commented Aug 25, 2020

Is your feature request related to a problem? Please describe.

Currently, if a tokenized string is shorter than max_length, the output is padded with 0s. So if max(tokenized string lengths) < max_length, we pay a performance penalty, because the compute time for Transformer models is often proportional to the sequence length of the input.
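To make the cost concrete, here is a minimal NumPy sketch (NumPy stands in for the GPU arrays, and `pad_batch` is a hypothetical helper, not a cuDF API). It pads the token IDs that the BertTokenizerFast example in this issue produces for ['a', 'a b', 'a b c d'] both ways and counts how many positions a model would have to process.

```python
import numpy as np

def pad_batch(token_lists, length, pad_id=0):
    """Right-pad each row of token IDs with pad_id to `length` (truncating longer rows)."""
    out = np.full((len(token_lists), length), pad_id, dtype=np.int32)
    for i, ids in enumerate(token_lists):
        ids = ids[:length]
        out[i, :len(ids)] = ids
    return out

# Token IDs for ['a', 'a b', 'a b c d'] (bert-base-uncased, no special tokens).
token_lists = [[1037], [1037, 1038], [1037, 1038, 1039, 1040]]

longest = max(len(ids) for ids in token_lists)
to_longest = pad_batch(token_lists, longest)  # proposed: pad to the longest row
to_max_length = pad_batch(token_lists, 64)    # current: pad to max_length=64

print(to_longest.shape, to_max_length.shape)  # (3, 4) (3, 64)
print(to_longest.size, to_max_length.size)    # 12 192
```

For this batch, padding to max_length=64 makes the model see 192 positions instead of 12, and all but 7 of them are padding.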

HuggingFace's tokenizer pads to the longest input sequence in the batch when padding=True is passed and max_length / pad_to_max_length are not provided. We should try to follow that; it is especially beneficial for the streaming cases that feature #5868 will help with.

See below example:

Padding to max sequence length (Proposed Default Behavior)

from transformers import BertTokenizerFast


tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
data = ['a', 'a b', 'a b c d']
# padding=True pads every row to the longest sequence in the batch
output = tokenizer.batch_encode_plus(data, padding=True, add_special_tokens=False, return_tensors='pt')
output['input_ids']

tensor([[1037,    0,    0,    0],
        [1037, 1038,    0,    0],
        [1037, 1038, 1039, 1040]])

Padding to max_length (Current Default Behavior)

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)
# padding='max_length' is the non-deprecated spelling of pad_to_max_length=True
output = tokenizer.batch_encode_plus(
    data, truncation=True, max_length=64, padding='max_length',
    add_special_tokens=False, return_tensors='pt'
)
output['input_ids']
tensor([[1037,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [1037, 1038,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [1037, 1038, 1039, 1040,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]])

Related Implications:

a. We might have to switch from returning one-dimensional CuPy arrays to two-dimensional arrays for token IDs and attention masks. Most workflows already reshape them to 2-D, so this should not add a performance penalty.
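A minimal NumPy sketch of that reshape, assuming (an assumption for illustration, not the documented cuDF output layout) that the tokenizer returns flat token-ID and mask arrays of length num_rows * max_length in row-major order with right padding. The hypothetical helper `to_2d_trimmed` views them as 2-D and then drops the trailing all-padding columns, which yields the pad-to-longest result from a single tokenization pass. A small max_length of 8 keeps the example short; CuPy arrays support the same reshape, sum, max, and slicing calls.

```python
import numpy as np  # cupy.ndarray supports the same reshape / sum / max / slicing used here

def to_2d_trimmed(flat_ids, flat_mask, max_length):
    """Reshape flat (num_rows * max_length) outputs to 2-D and drop all-padding columns.

    Assumes row-major layout and right padding, so each row's real tokens sit in
    its leading columns, and the mask is 1 for real tokens and 0 for padding.
    """
    ids = flat_ids.reshape(-1, max_length)
    mask = flat_mask.reshape(-1, max_length)
    longest = int(mask.sum(axis=1).max())  # length of the longest row in the batch
    return ids[:, :longest], mask[:, :longest]

max_length = 8
flat_ids = np.array([1037, 0, 0, 0, 0, 0, 0, 0,
                     1037, 1038, 0, 0, 0, 0, 0, 0,
                     1037, 1038, 1039, 1040, 0, 0, 0, 0], dtype=np.int32)
# Deriving the mask from != 0 only works because 0 is the padding ID here.
flat_mask = (flat_ids != 0).astype(np.int32)

ids, mask = to_2d_trimmed(flat_ids, flat_mask, max_length)
print(ids.shape)  # (3, 4)
```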

Describe alternatives you've considered

Currently, a user can run the tokenizer twice:

  1. The first pass finds the maximum sequence length; skip the to_dlpack call in this pass.
  2. The second pass feeds that sequence length back to the tokenizer as max_length, and its output is converted to tensors using DLPack.

I do the above for gpu-bdb q27 HF. Since most of the time is spent in to_dlpack, this workaround should not have big performance implications.
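The two-pass workaround can be sketched in plain Python as follows. `toy_tokenize` and its tiny vocabulary are hypothetical stand-ins for the cuDF subword tokenizer (this is not its real API), and NumPy stands in for CuPy; the sketch only illustrates the control flow of measuring the longest row first and then re-tokenizing with that length.

```python
import numpy as np

# Hypothetical stand-in for the cuDF subword tokenizer: a toy vocabulary and
# whitespace splitting, returning fixed-width token IDs and an attention mask.
VOCAB = {"a": 1037, "b": 1038, "c": 1039, "d": 1040}

def toy_tokenize(texts, max_length):
    ids = np.zeros((len(texts), max_length), dtype=np.int32)
    for i, text in enumerate(texts):
        tokens = [VOCAB[w] for w in text.split()][:max_length]  # truncate to max_length
        ids[i, :len(tokens)] = tokens
    mask = (ids != 0).astype(np.int32)  # 0 is the padding ID in this toy setup
    return ids, mask

data = ["a", "a b", "a b c d"]

# Pass 1: tokenize with a generous max_length only to measure row lengths
# (no to_dlpack / tensor conversion in this pass).
_, mask = toy_tokenize(data, max_length=64)
longest = int(mask.sum(axis=1).max())

# Pass 2: re-tokenize with max_length set to the longest row; this is the
# output that would then be handed to the model via DLPack.
ids, mask = toy_tokenize(data, max_length=longest)
print(ids.shape)  # (3, 4)
```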

CC: @raykallen, @randerzander, @davidwendt

@VibhuJawa VibhuJawa added feature request New feature or request Needs Triage Need team to review and classify strings strings issues (C++ and Python) labels Aug 25, 2020
@kkraus14 kkraus14 added Python Affects Python cuDF API. libcudf Affects libcudf (C++/CUDA) code. and removed Needs Triage Need team to review and classify labels Aug 27, 2020
@github-actions

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

Projects
Status: Todo
Development

No branches or pull requests

4 participants