Subword Tokenizer HuggingFace like API (#7942)
This PR closes #5868 by adding a new tokenizer API. 

We are seeing speedups even at low batch sizes (10 and 100), so this should unlock some inference/training use cases for us.

## Benchmarks:
(Thanks to @davidwendt for writing the super-fast tokenizer 💥)

| Batch Size | HuggingFace | RAPIDS Old API | RAPIDS New Tokenizer API | New API Speedup vs HuggingFace | New API Speedup vs Old API |
|-|-|-|-|-|-|
| 1 | 0.000242 | 0.006890 | 0.000497 | 0.487 | 13.863 |
| 10 | 0.002800 | 0.007030 | 0.000516 | 5.426 | 13.624 |
| 100 | 0.016200 | 0.007140 | 0.000537 | 30.168 | 13.296 |
| 1000 | 0.149000 | 0.007150 | 0.000517 | 288.201 | 13.830 |


## API Comparison to HuggingFace:

The goal of this PR is to ensure our API matches HuggingFace's as closely as possible to make porting easy.

Proposed API in this PR:

```python
from cudf.core.subword_tokenizer import SubwordTokenizer
tokenizer = SubwordTokenizer('bert-base-cased-vocab-hash.txt', do_lower_case=False)

output = tokenizer(str_series,
                   max_num_rows=len(str_series),
                   truncation=True,
                   max_length=seq_len,
                   padding='max_length',
                   add_special_tokens=False,
                   return_tensors='pt')
```           
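For completeness, here is a minimal sketch of the proposed API driven end to end. The names `str_series` and `seq_len` are not defined in the snippet above, so the values below are assumptions, as are the HuggingFace-style output keys:

```python
# Hypothetical setup for the snippet above: `str_series` is assumed to be a
# cuDF string Series and `seq_len` a target sequence length.
import cudf
from cudf.core.subword_tokenizer import SubwordTokenizer

str_series = cudf.Series(["This is a sentence.", "Another sentence to tokenize."])
seq_len = 64

tokenizer = SubwordTokenizer('bert-base-cased-vocab-hash.txt', do_lower_case=False)

output = tokenizer(str_series,
                   max_num_rows=len(str_series),
                   truncation=True,
                   max_length=seq_len,
                   padding='max_length',
                   add_special_tokens=False,
                   return_tensors='pt')

# Assuming the encoding exposes HuggingFace-style keys, return_tensors='pt'
# yields torch tensors that already live on the GPU.
input_ids = output['input_ids']            # shape: (num_rows, seq_len)
attention_mask = output['attention_mask']  # 1 for real tokens, 0 for padding
```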

HuggingFace API:
```python
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased', do_lower_case=False)

output = tokenizer(input_sentence_ls,
                   truncation=True,
                   max_length=seq_len,
                   padding='max_length',
                   add_special_tokens=False,
                   return_tensors='pt')
output_d = {k: v.cuda() for k, v in output.items()}
```
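Note the extra step at the end of the HuggingFace flow: its tokenizer produces host-side tensors that must be copied to the GPU with `.cuda()`, whereas the cuDF tokenizer operates on device data and, with `return_tensors='pt'`, hands back tensors that are already on the GPU, so that copy disappears.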

## TODO:
- [x] Add tests
- [x] Throw appropriate warnings for HuggingFace discrepancies 
- [x] API checks 
- [x] [Benchmark/Example Notebook](https://nbviewer.jupyter.org/gist/VibhuJawa/350a8479b10be3591dd9c4d5da3cfc3b)



CC: @raykallen, @BartleyR  (from the cyber team)

CC: @randerzander , @beckernick  (from the workflows team)

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Keith Kraus (https://github.com/kkraus14)

URL: #7942
VibhuJawa authored May 3, 2021
1 parent ad081ae commit 36eaa06
Showing 19 changed files with 8,491 additions and 5,596 deletions.
1 change: 1 addition & 0 deletions conda/environments/cudf_dev_cuda11.0.yml
@@ -59,6 +59,7 @@ dependencies:
- protobuf
- nvtx>=0.2.1
- cachetools
- transformers
- pip:
- git+https://github.com/dask/dask.git@main
- git+https://github.com/dask/distributed.git@main
1 change: 1 addition & 0 deletions conda/environments/cudf_dev_cuda11.1.yml
@@ -59,6 +59,7 @@ dependencies:
- protobuf
- nvtx>=0.2.1
- cachetools
- transformers
- pip:
- git+https://github.com/dask/dask.git@main
- git+https://github.com/dask/distributed.git@main
1 change: 1 addition & 0 deletions conda/environments/cudf_dev_cuda11.2.yml
@@ -59,6 +59,7 @@ dependencies:
- protobuf
- nvtx>=0.2.1
- cachetools
- transformers
- pip:
- git+https://github.com/dask/dask.git@main
- git+https://github.com/dask/distributed.git@main
7 changes: 7 additions & 0 deletions docs/cudf/source/api.rst
@@ -206,6 +206,13 @@ Window
.. autoclass:: Rolling
:members:

SubwordTokenizer
----------------
.. currentmodule:: cudf.core.subword_tokenizer

.. autoclass:: SubwordTokenizer
:members:
:special-members: __call__

General utility functions
-------------------------
28 changes: 27 additions & 1 deletion python/cudf/cudf/_lib/cpp/nvtext/subword_tokenize.pxd
@@ -3,7 +3,8 @@
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libc.stdint cimport uint32_t
from libc.stdint cimport uint16_t, uint32_t


from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
@@ -17,6 +18,31 @@ cdef extern from "nvtext/subword_tokenize.hpp" namespace "nvtext" nogil:
unique_ptr[column] tensor_attention_mask
unique_ptr[column] tensor_metadata

cdef struct hashed_vocabulary "nvtext::hashed_vocabulary":
uint16_t first_token_id
uint16_t separator_token_id
uint16_t unknown_token_id
uint32_t outer_hash_a
uint32_t outer_hash_b
uint16_t num_bin
unique_ptr[column] table
unique_ptr[column] bin_coefficients
unique_ptr[column] bin_offsets

cdef unique_ptr[hashed_vocabulary] load_vocabulary_file(
const string &filename_hashed_vocabulary
) except +

cdef tokenizer_result subword_tokenize(
const column_view & strings,
hashed_vocabulary & hashed_vocablary_obj,
uint32_t max_sequence_length,
uint32_t stride,
bool do_lower,
bool do_truncate,
uint32_t max_rows_tensor
) except +

cdef tokenizer_result subword_tokenize(
const column_view &strings,
const string &filename_hashed_vocabulary,
58 changes: 52 additions & 6 deletions python/cudf/cudf/_lib/nvtext/subword_tokenize.pyx
@@ -9,27 +9,74 @@ from libc.stdint cimport uintptr_t

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.nvtext.subword_tokenize cimport (
from cudf._lib.cpp.nvtext.subword_tokenize cimport(
subword_tokenize as cpp_subword_tokenize,
hashed_vocabulary as cpp_hashed_vocabulary,
load_vocabulary_file as cpp_load_vocabulary_file,
tokenizer_result as cpp_tokenizer_result,
move as tr_move
move as tr_move,
)
from cudf._lib.column cimport Column


def subword_tokenize(
cdef class Hashed_Vocabulary:
cdef unique_ptr[cpp_hashed_vocabulary] c_obj

def __cinit__(self, hash_file):
cdef string c_hash_file = <string>str(hash_file).encode()
with nogil:
self.c_obj = move(cpp_load_vocabulary_file(c_hash_file))


def subword_tokenize_inmem_hash(
Column strings,
object hash_file,
Hashed_Vocabulary hashed_vocabulary,
uint32_t max_sequence_length=64,
uint32_t stride=48,
bool do_lower=True,
bool do_truncate=False,
uint32_t max_rows_tensor=500
):
"""
Subword tokenizes text series by using the pre-loaded hashed vocabulary
"""
cdef column_view c_strings = strings.view()
cdef string c_hash_file = <string>str(hash_file).encode()
cdef cpp_tokenizer_result c_result
with nogil:
c_result = tr_move(
cpp_subword_tokenize(
c_strings,
hashed_vocabulary.c_obj.get()[0],
max_sequence_length,
stride,
do_lower,
do_truncate,
max_rows_tensor
)
)
# return the 3 tensor components
tokens = Column.from_unique_ptr(move(c_result.tensor_token_ids))
masks = Column.from_unique_ptr(move(c_result.tensor_attention_mask))
metadata = Column.from_unique_ptr(move(c_result.tensor_metadata))
return tokens, masks, metadata


def subword_tokenize_vocab_file(
Column strings,
object hash_file,
uint32_t max_sequence_length=64,
uint32_t stride=48,
bool do_lower=True,
bool do_truncate=False,
uint32_t max_rows_tensor=500
):
"""
Subword tokenizes text series by using the hashed vocabulary
stored on disk
"""
cdef column_view c_strings = strings.view()
cdef cpp_tokenizer_result c_result
cdef string c_hash_file = <string>str(hash_file).encode()
with nogil:
c_result = tr_move(
cpp_subword_tokenize(
@@ -42,7 +89,6 @@ def subword_tokenize(
max_rows_tensor
)
)

# return the 3 tensor components
tokens = Column.from_unique_ptr(move(c_result.tensor_token_ids))
masks = Column.from_unique_ptr(move(c_result.tensor_attention_mask))
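For reference, a hypothetical sketch of how the two new Cython entry points above might be driven from Python: the hashed vocabulary is loaded once via `Hashed_Vocabulary` and then reused across `subword_tokenize_inmem_hash` calls, which is what lets the new `SubwordTokenizer` avoid re-reading the hash file on every batch. The import path and the use of `Series._column` are inferences from the diff, not code shipped in this PR.

```python
# Hypothetical usage of the new in-memory hashed-vocabulary path; names and
# import path are inferred from the diff above.
import cudf
from cudf._lib.nvtext.subword_tokenize import (
    Hashed_Vocabulary,
    subword_tokenize_inmem_hash,
)

# Parse and hash the vocabulary file once (this is the expensive step).
vocab = Hashed_Vocabulary("bert-base-cased-vocab-hash.txt")

strings = cudf.Series(["tokenize this text", "and this text too"])

# Reuse the in-memory vocabulary for every batch; no file I/O per call.
tokens, masks, metadata = subword_tokenize_inmem_hash(
    strings._column,            # underlying string Column
    vocab,
    max_sequence_length=64,
    stride=48,
    do_lower=False,
    do_truncate=True,
    max_rows_tensor=len(strings),
)
```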
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/column/string.py
@@ -41,7 +41,7 @@
porter_stemmer_measure as cpp_porter_stemmer_measure,
)
from cudf._lib.nvtext.subword_tokenize import (
subword_tokenize as cpp_subword_tokenize,
subword_tokenize_vocab_file as cpp_subword_tokenize_vocab_file,
)
from cudf._lib.nvtext.tokenize import (
_count_tokens_column as cpp_count_tokens_column,
@@ -4617,7 +4617,7 @@ def subword_tokenize(
array([[0, 0, 2],
[1, 0, 1]], dtype=uint32)
"""
tokens, masks, metadata = cpp_subword_tokenize(
tokens, masks, metadata = cpp_subword_tokenize_vocab_file(
self._column,
hash_file,
max_length,
