[FEA] Create separate API for loading the vocabulary file for the subword-tokenizer #5868
Comments
I have a slightly unrelated question: how hard would it be to also support the raw vocab.txt file (in addition to the vocab hash-table)? The reasons for this ask are:

a. …

b. It paves the road for the decode feature (essentially a dictionary lookup), which we don't have in our tokenizer right now and currently have to rely on CPUs to do.

c. The current process of creating the hash-table may have some bugs in it (see #5760 and #5765). It is also not under any kind of tests or CI (AFAIK) (feel free to correct me @raykallen / @BartleyR), so it might be worthwhile adding it to a tested location in cudf anyway.

Also, note that the suggestion here is to have both options (vocab.txt as well as vocab-hash-table), which the user can provide as they wish. Please feel free to shoot this idea down if it's expensive/difficult to implement; this is more of a suggestion.

edit: Added point b.
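To illustrate point (b): once a plain vocab.txt is available, decode is essentially a dictionary lookup. Below is a minimal CPU-side sketch, assuming the usual BERT-style vocab format of one token per line where the line number is the token id; the file name and ids are illustrative, not part of any existing cudf API.

```python
# Minimal decode sketch (CPU-side, illustration only).
# Assumes a BERT-style vocab.txt: one token per line, line number == token id.
with open("vocab.txt", encoding="utf-8") as f:
    id_to_token = {i: tok.rstrip("\n") for i, tok in enumerate(f)}

def decode(token_ids):
    """Map token ids back to their string tokens via a dictionary lookup."""
    return [id_to_token[i] for i in token_ids]

print(decode([1037, 2338]))  # token ids here are illustrative
```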
This would involve incorporating this Python file inside the cudf Python package and having someone maintain it.
This issue has been marked rotten due to no recent activity in the past 90d. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
We had a call offline to discuss a plan for this and came up with an initial plan of action, along with follow-up work to tackle once that is finished.
On the Python end-user side, the requested API can loosely follow `BertTokenizerFast`, and we will provide a subset of the functionality as a first pass:

```python
class SubwordTokenizer(hash_file: str, max_length: int = 64, stride: int = 48,
                       do_lower: bool = True, do_truncate: bool = False,
                       max_rows_tensor: int = 500)
```

The function to include in this class is `encode`, which converts a cudf string series into sequences of ids:

```python
# cudf series -> Tuple[cupy.core.core.ndarray, cupy.core.core.ndarray, cupy.core.core.ndarray]
def encode(series)
```

The end user will call it like below:

```python
ser = cudf.Series(['this is the', 'best book'])
tokenizer = SubwordTokenizer('voc_hash.txt', max_length=max_length, stride=stride)
tokens, masks, metadata = tokenizer.encode(ser)
```

CC: @BartleyR / @raykallen for inputs.
Reference #5868

This PR changes `nvtext::load_vocabulary_file` to return a unique-pointer, making it easier to manage in a Python/Cython class object. The original signature returned a flat structure containing unique-pointers, which would make it difficult to copy and manage. The corresponding gtests and gbenchmarks were updated for this API change.

Authors:
- David (@davidwendt)

Approvers:
- Conor Hoekstra (@codereport)
- Karthikeyan (@karthikeyann)

URL: #7424
This PR closes #5868 by adding a new tokenizer API. We are seeing speedups even at low batch sizes (10/100), so this should potentially unlock some inference/training use cases for us.

## Benchmarks:

(Thanks to @davidwendt for writing the super fast tokenizer 💥)

| Batch Size | HuggingFace | Rapids Old API | Rapids Tokenizer API | Tokenizer API Speedup vs HuggingFace | Rapids New API Speedup |
|- |- |- |- |- |- |
| 1 | 0.000242 | 0.006890 | 0.000497 | 0.487 | 13.863 |
| 10 | 0.002800 | 0.007030 | 0.000516 | 5.426 | 13.624 |
| 100 | 0.016200 | 0.007140 | 0.000537 | 30.168 | 13.296 |
| 1000 | 0.149000 | 0.007150 | 0.000517 | 288.201 | 13.830 |

## API Comparison to HuggingFace:

The goal of this PR is to ensure our API matches up with HuggingFace as much as possible to help with ease of porting.

Proposed API in this PR:

```python
from cudf.core.subword_tokenizer import SubwordTokenizer

tokenizer = SubwordTokenizer('bert-base-cased-vocab-hash.txt', do_lower_case=False)
output = tokenizer(str_series,
                   max_num_rows=len(str_series),
                   truncation=True,
                   max_length=seq_len,
                   padding='max_length',
                   add_special_tokens=False,
                   return_tensors='pt')
```

HuggingFace API:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased', do_lower_case=False)
output = tokenizer(input_sentence_ls,
                   truncation=True,
                   max_length=seq_len,
                   padding='max_length',
                   add_special_tokens=False,
                   return_tensors='pt')
output_d = {k: v.cuda() for k, v in output.items()}
```

## TODO:

- [x] Add tests
- [x] Throw appropriate warnings for HuggingFace discrepancies
- [x] API checks
- [x] [Benchmark/Example Notebook](https://nbviewer.jupyter.org/gist/VibhuJawa/350a8479b10be3591dd9c4d5da3cfc3b)

CC: @raykallen, @BartleyR (from the cyber team)
CC: @randerzander, @beckernick (from the workflows team)

Authors:
- Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
- GALI PREM SAGAR (https://github.com/galipremsagar)
- AJ Schmidt (https://github.com/ajschmidt8)
- Keith Kraus (https://github.com/kkraus14)

URL: #7942
Is your feature request related to a problem? Please describe.
The current `subword_tokenize` API in cudf accepts the vocabulary hash-file as a parameter. Profiling shows that processing this file into GPU memory accounts for about 14% of the overall execution time for 1M strings of 15 bytes each, using the CLX BERT vocab hash file (755KB).
From some previous discussions with @VibhuJawa @raykallen @BartleyR, multiple `subword_tokenize` calls over batches of data are likely to use the same vocabulary, so loading the hash-file on each tokenize call would be wasteful. Note that the load time is fixed, so for a smaller set of strings its share of the overall time is higher: in a run with only 100K strings and the same vocab file, the vocab-file load is more than half of the overall time.
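As a rough sanity check of why the fixed load cost dominates smaller inputs, here is a back-of-the-envelope sketch; the 14% split comes from the 1M-string profile above, and everything else is illustrative.

```python
# Model: total = load + tokenize, where load is fixed and tokenize scales with input size.
# From the 1M-string profile above: load ~= 14% of total, tokenize ~= 86%.
load, tokenize_1m = 0.14, 0.86

for n_strings in (1_000_000, 100_000):
    tokenize = tokenize_1m * n_strings / 1_000_000  # tokenize work scales linearly
    share = load / (load + tokenize)
    print(f"{n_strings:>9} strings: vocab load ~= {share:.0%} of total")
# Prints ~14% for 1M strings and ~62% for 100K strings,
# consistent with "more than half" in the 100K run above.
```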
Describe the solution you'd like
Create a separate API to load the hash-table for the tokenizer. The resulting object could then be passed to each tokenize call that uses the same vocabulary, thereby speeding up the processing of multiple batches of data.
This may also align better with the HuggingFace tokenizer, which at least stores the vocabulary file names as an attribute, if not the preprocessed hash data itself, internally.
The libcudf API layer already has APIs that separate the file load (`nvtext::load_vocabulary_file`) from the tokenize itself.
So this feature request is to design and implement an equivalent cudf Python API.
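For illustration, here is roughly how that separation looks from the Python side once the vocabulary is loaded a single time and reused across batches. The constructor and call signature mirror the `SubwordTokenizer` API described in #7942 earlier in this thread; the file name and batch loop are illustrative, not prescriptive.

```python
import cudf
from cudf.core.subword_tokenizer import SubwordTokenizer

# Parse the vocabulary hash file once (file name is illustrative).
tokenizer = SubwordTokenizer('bert-base-cased-vocab-hash.txt', do_lower_case=False)

# Reuse the loaded vocabulary for every batch instead of re-reading the file.
batches = [cudf.Series(['this is the', 'best book'])]  # stand-in for real batches
for str_series in batches:
    output = tokenizer(str_series,
                       max_num_rows=len(str_series),
                       truncation=True,
                       max_length=64,
                       padding='max_length',
                       add_special_tokens=False,
                       return_tensors='pt')
```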
@randerzander