Merge pull request #311 from huggingface/optional-parallelism
Make parallelism optional
n1t0 authored Jun 26, 2020
2 parents 74d812d + bb668bc commit 6d531a4
Showing 23 changed files with 336 additions and 36 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python.yml
@@ -67,7 +67,10 @@ jobs:

  build_and_test:
    name: Check everything builds & tests
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest, macos-latest]
    steps:
      - name: Checkout repository
        uses: actions/checkout@v1
6 changes: 6 additions & 0 deletions bindings/python/CHANGELOG.md
@@ -18,6 +18,10 @@ This adds some methods to easily save/load an entire tokenizer (`from_str`, `from_file`)
activation of the Tensor Cores, while ensuring padding to a multiple of 8. Use with
`enable_padding(pad_to_multiple_of=8)` for example.
- [#298]: Ability to get the currently set truncation/padding params
- [#311]: Ability to enable/disable the parallelism using the `TOKENIZERS_PARALLELISM` environment
variable. This is especially useful when using `multiprocessing` capabilities, with the `fork`
start method, which happens to be the default on Linux systems. Without disabling the parallelism,
the process deadlocks while encoding. (Cf [#187] for more information)

### Changed
- Improved errors generated during truncation: When the provided max length is too low are
@@ -190,6 +194,7 @@ delimiter (Works like `.split(delimiter)`)
- Fix a bug with the IDs associated with added tokens.
- Fix a bug that was causing crashes in Python 3.5

[#311]: https://github.com/huggingface/tokenizers/pull/311
[#309]: https://github.com/huggingface/tokenizers/pull/309
[#289]: https://github.com/huggingface/tokenizers/pull/289
[#286]: https://github.com/huggingface/tokenizers/pull/286
@@ -207,6 +212,7 @@ delimiter (Works like `.split(delimiter)`)
[#193]: https://github.com/huggingface/tokenizers/pull/193
[#190]: https://github.com/huggingface/tokenizers/pull/190
[#188]: https://github.com/huggingface/tokenizers/pull/188
[#187]: https://github.com/huggingface/tokenizers/issues/187
[#175]: https://github.com/huggingface/tokenizers/issues/175
[#174]: https://github.com/huggingface/tokenizers/issues/174
[#165]: https://github.com/huggingface/tokenizers/pull/165
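A usage sketch (not part of the diff) of how this new variable combines with `multiprocessing` under the default `fork` start method; the `tokenizer.json` path is a hypothetical placeholder:

```python
import os
import multiprocessing as mp

# Set before the forked child encodes anything; "true"/"false" toggles the
# Rust-side thread pool that this PR makes optional.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from tokenizers import Tokenizer

def worker(tokenizer):
    # Safe even under the "fork" start method, since parallelism is disabled.
    print(tokenizer.encode("hello world").tokens)

if __name__ == "__main__":
    tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical saved tokenizer
    tokenizer.encode("warm up")  # parent-side use that would otherwise poison the fork
    p = mp.Process(target=worker, args=(tokenizer,))
    p.start()
    p.join()
```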
22 changes: 22 additions & 0 deletions bindings/python/Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
@@ -13,6 +13,7 @@ rayon = "1.3"
typetag = "0.1"
serde = "1.0"
serde_json = "1.0"
libc = "0.2"

[dependencies.pyo3]
version = "0.9.2"
2 changes: 1 addition & 1 deletion bindings/python/examples/train_bert_wordpiece.py
@@ -47,4 +47,4 @@
)

# Save the files
-tokenizer.save(args.out, args.name)
+tokenizer.save_model(args.out, args.name)
2 changes: 1 addition & 1 deletion bindings/python/examples/train_bytelevel_bpe.py
@@ -44,7 +44,7 @@
)

# Save the files
-tokenizer.save(args.out, args.name)
+tokenizer.save_model(args.out, args.name)

# Restoring model from learned vocab/merges
tokenizer = ByteLevelBPETokenizer(
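Both examples switch from `save` to `save_model`, which writes only the model files (vocab and merges) rather than a whole serialized tokenizer. A minimal sketch of the resulting round-trip, with hypothetical corpus and output paths:

```python
from tokenizers import ByteLevelBPETokenizer

tokenizer = ByteLevelBPETokenizer()
tokenizer.train(["data.txt"])  # hypothetical training corpus

# Writes e.g. out/name-vocab.json and out/name-merges.txt (model files only)
tokenizer.save_model("out", "name")

# Later: restore the model from those files
tokenizer = ByteLevelBPETokenizer("out/name-vocab.json", "out/name-merges.txt")
```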
28 changes: 28 additions & 0 deletions bindings/python/src/lib.rs
@@ -1,3 +1,5 @@
extern crate tokenizers as tk;

mod decoders;
mod encoding;
mod error;
@@ -13,6 +15,23 @@ mod utils;
use pyo3::prelude::*;
use pyo3::wrap_pymodule;

// For users using multiprocessing in python, it is quite easy to fork the process running
// tokenizers, ending up with a deadlock because we internally make use of multithreading. So
// we register a callback to be called in the event of a fork so that we can warn the user.
static mut REGISTERED_FORK_CALLBACK: bool = false;
extern "C" fn child_after_fork() {
    if !tk::parallelism::is_parallelism_configured() {
        println!(
            "The current process just got forked. Disabling parallelism to avoid deadlocks..."
        );
        println!(
            "To disable this warning, please explicitly set {}=(true | false)",
            tk::parallelism::ENV_VARIABLE
        );
        tk::parallelism::set_parallelism(false);
    }
}

/// Trainers Module
#[pymodule]
fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
/// Tokenizers Module
#[pymodule]
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
    // Register the fork callback
    #[cfg(target_family = "unix")]
    unsafe {
        if !REGISTERED_FORK_CALLBACK {
            libc::pthread_atfork(None, None, Some(child_after_fork));
            REGISTERED_FORK_CALLBACK = true;
        }
    }

    m.add_class::<tokenizer::Tokenizer>()?;
    m.add_class::<tokenizer::AddedToken>()?;
    m.add_class::<encoding::Encoding>()?;
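`pthread_atfork` here registers a child-side hook, the same mechanism CPython exposes as `os.register_at_fork` (Python 3.7+). Purely as an illustration of what the callback above does, expressed in Python (the helper name is invented):

```python
import os

def _disable_parallelism_after_fork():
    # Mirrors child_after_fork above: if the user never set the variable
    # explicitly, warn and fall back to sequential encoding in the child.
    if "TOKENIZERS_PARALLELISM" not in os.environ:
        print("The current process just got forked. Disabling parallelism to avoid deadlocks...")
        print("To disable this warning, please explicitly set TOKENIZERS_PARALLELISM=(true | false)")
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

# os.register_at_fork is CPython's counterpart of libc's pthread_atfork.
os.register_at_fork(after_in_child=_disable_parallelism_after_fork)
```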
4 changes: 2 additions & 2 deletions bindings/python/src/models.rs
@@ -6,8 +6,8 @@ use super::utils::Container;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
-use rayon::prelude::*;
use std::path::Path;
+use tk::parallelism::*;

#[pyclass]
struct EncodeInput {
@@ -154,7 +154,7 @@ impl Model {
    fn encode_batch(&self, sequences: Vec<EncodeInput>, type_id: u32) -> PyResult<Vec<Encoding>> {
        ToPyResult(self.model.execute(|model| {
            sequences
-                .into_par_iter()
+                .into_maybe_par_iter()
                .map(|sequence| {
                    let sequence = sequence.into_input();
                    if sequence.is_empty() {
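`into_maybe_par_iter` comes from the new `tk::parallelism` module (backed by the `rayon-cond` crate added below): it behaves like rayon's `into_par_iter` when parallelism is enabled, and like a plain sequential iterator otherwise. A rough Python analogue of the idea, with invented names, might look like:

```python
import os
from multiprocessing.dummy import Pool  # thread pool, loosely mirroring rayon

def maybe_parallel_map(fn, items):
    # Parallel unless TOKENIZERS_PARALLELISM is set to a false-y value
    # (the accepted spellings here are illustrative, not the library's exact list).
    enabled = os.environ.get("TOKENIZERS_PARALLELISM", "true").lower() not in ("0", "false", "off")
    if enabled:
        with Pool() as pool:
            return pool.map(fn, items)
    return [fn(x) for x in items]

print(maybe_parallel_map(str.upper, ["hi", "there"]))  # ['HI', 'THERE']
```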
7 changes: 6 additions & 1 deletion bindings/python/tests/bindings/test_tokenizer.py
@@ -1,6 +1,6 @@
import pickle
import pytest
-from ..utils import data_dir, roberta_files, bert_files
+from ..utils import data_dir, roberta_files, bert_files, multiprocessing_with_parallelism

from tokenizers import AddedToken, Tokenizer, Encoding
from tokenizers.models import Model, BPE, WordPiece
@@ -289,3 +289,8 @@ def test_post_process(self):
        # Can post process a pair of encodings
        output = tokenizer.post_process(encoding, pair_encoding)
        assert output.tokens == ["my", "pair", "[PAD]", "[PAD]"]

    def test_multiprocessing_with_parallelism(self):
        tokenizer = Tokenizer(BPE())
        multiprocessing_with_parallelism(tokenizer, False)
        multiprocessing_with_parallelism(tokenizer, True)
7 changes: 6 additions & 1 deletion bindings/python/tests/implementations/test_bert_wordpiece.py
@@ -1,4 +1,4 @@
-from ..utils import data_dir, bert_files
+from ..utils import data_dir, bert_files, multiprocessing_with_parallelism
from tokenizers import BertWordPieceTokenizer


@@ -19,3 +19,8 @@ def test_basic_encode(self, bert_files):
        assert output.tokens == ["my", "name", "is", "john", "pair"]
        assert output.offsets == [(0, 2), (3, 7), (8, 10), (11, 15), (0, 4)]
        assert output.type_ids == [0, 0, 0, 0, 1]

    def test_multiprocessing_with_parallelism(self, bert_files):
        tokenizer = BertWordPieceTokenizer(bert_files["vocab"])
        multiprocessing_with_parallelism(tokenizer, False)
        multiprocessing_with_parallelism(tokenizer, True)
7 changes: 6 additions & 1 deletion bindings/python/tests/implementations/test_byte_level_bpe.py
@@ -1,4 +1,4 @@
-from ..utils import data_dir, roberta_files
+from ..utils import data_dir, roberta_files, multiprocessing_with_parallelism
from tokenizers import ByteLevelBPETokenizer


@@ -79,3 +79,8 @@ def test_lowerspace(self, roberta_files):
"Ġlazy",
"Ġdog",
]

def test_multiprocessing_with_parallelism(self, roberta_files):
tokenizer = ByteLevelBPETokenizer(roberta_files["vocab"], roberta_files["merges"])
multiprocessing_with_parallelism(tokenizer, False)
multiprocessing_with_parallelism(tokenizer, True)
7 changes: 6 additions & 1 deletion bindings/python/tests/implementations/test_char_bpe.py
@@ -1,4 +1,4 @@
-from ..utils import data_dir, openai_files
+from ..utils import data_dir, openai_files, multiprocessing_with_parallelism
from tokenizers import CharBPETokenizer


@@ -42,3 +42,8 @@ def test_decoding(self, openai_files):
        tokenizer = CharBPETokenizer(openai_files["vocab"], openai_files["merges"], lowercase=True)
        decoded = tokenizer.decode(tokenizer.encode("my name is john").ids)
        assert decoded == "my name is john"

    def test_multiprocessing_with_parallelism(self, openai_files):
        tokenizer = CharBPETokenizer(openai_files["vocab"], openai_files["merges"])
        multiprocessing_with_parallelism(tokenizer, False)
        multiprocessing_with_parallelism(tokenizer, True)
31 changes: 31 additions & 0 deletions bindings/python/tests/utils.py
@@ -1,3 +1,4 @@
import multiprocessing as mp
import os
import requests
import pytest
@@ -56,3 +57,33 @@ def openai_files(data_dir):
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt"
),
}


def multiprocessing_with_parallelism(tokenizer, enabled: bool):
    """
    This helper can be used to test that disabling parallelism avoids deadlocks when the
    same tokenizer is used after forking.
    """
    # It's essential to this test that we call 'encode' or 'encode_batch'
    # before the fork. This causes the main process to "lock" some resources
    # provided by the Rust "rayon" crate that are needed for parallel processing.
    tokenizer.encode("Hi")
    tokenizer.encode_batch(["hi", "there"])

    def encode(tokenizer):
        tokenizer.encode("Hi")
        tokenizer.encode_batch(["hi", "there"])

    # Make sure this environment variable is set before the fork happens
    os.environ["TOKENIZERS_PARALLELISM"] = str(enabled)
    p = mp.Process(target=encode, args=(tokenizer,))
    p.start()
    p.join(timeout=1)

    # At this point the process should have successfully exited, depending on whether parallelism
    # was activated or not. So we check the status and kill it if needed
    alive = p.is_alive()
    if alive:
        p.terminate()

    assert (alive and mp.get_start_method() == "fork") == enabled
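    # Note (added for clarity): the assert above encodes the expected behaviour;
    # the child hangs (is still alive after the timeout) exactly when parallelism
    # was enabled and the start method is "fork"; under "spawn" it always exits.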
3 changes: 3 additions & 0 deletions tokenizers/CHANGELOG.md
@@ -43,6 +43,8 @@ using serde. It is now easy to save/load an entire tokenizer.
- [#289]: Ability to pad to a multiple of a specified value. This is especially useful to ensure
activation of the Tensor Cores, while ensuring padding to a multiple of 8.
- [#298]: Ability to get the currently set truncation/padding params
- [#311]: Ability to enable/disable the parallelism using the `TOKENIZERS_PARALLELISM` environment
variable.

### How to migrate
- Replace any `XXX_to_YYY_offsets()` method call by any of the new ones.
@@ -117,6 +119,7 @@ advised, but that's not the question)
split up in multiple bytes
- [#174]: The `LongestFirst` truncation strategy had a bug

[#311]: https://github.com/huggingface/tokenizers/pull/311
[#309]: https://github.com/huggingface/tokenizers/pull/309
[#298]: https://github.com/huggingface/tokenizers/pull/298
[#289]: https://github.com/huggingface/tokenizers/pull/289
1 change: 1 addition & 0 deletions tokenizers/Cargo.toml
@@ -36,6 +36,7 @@ onig = { version = "6.0", default-features = false }
regex = "1.3"
regex-syntax = "0.6"
rayon = "1.3"
rayon-cond = { version = "*", git = "https://github.com/n1t0/rayon-cond" }
serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0"
typetag = "0.1"
3 changes: 3 additions & 0 deletions tokenizers/src/lib.rs
@@ -57,3 +57,6 @@ pub mod utils;

// Re-export from tokenizer
pub use tokenizer::*;

// Re-export also parallelism utils
pub use utils::parallelism;
12 changes: 6 additions & 6 deletions tokenizers/src/models/bpe/trainer.rs
@@ -1,9 +1,9 @@
#![allow(clippy::map_entry)]

use super::{Pair, WithFirstLastIterator, Word, BPE};
+use crate::parallelism::*;
use crate::tokenizer::{AddedToken, Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
-use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

@@ -352,7 +352,7 @@ impl BpeTrainer {
        p: &Option<ProgressBar>,
    ) -> (HashMap<Pair, i32>, HashMap<Pair, HashSet<usize>>) {
        words
-            .par_iter()
+            .maybe_par_iter()
            .enumerate()
            .map(|(i, word)| {
                let mut pair_counts = HashMap::new();
@@ -379,10 +379,10 @@
                        h
                    });
                    *pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
                }

-                    if let Some(p) = &p {
-                        p.inc(1);
-                    }
+                if let Some(p) = &p {
+                    p.inc(1);
+                }

                (pair_counts, where_to_update)
@@ -499,7 +499,7 @@ impl BpeTrainer {
        // Merge the new pair in every words
        let changes = top
            .pos
-            .par_iter()
+            .maybe_par_iter()
            .flat_map(|i| {
                let w = &words[*i] as *const _ as *mut _;
                // We can merge each of these words in parallel here because each position
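The trainer hunks above follow a classic map-reduce shape: each (possibly parallel) worker builds a local pair-count map, and the maps are merged afterwards. An illustrative Python sketch of that pattern (not library code):

```python
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

def count_pairs(word):
    # Local map per word, like the per-word HashMaps built in the closure above.
    return Counter(zip(word, word[1:]))

words = ["hello", "hugging", "face"]
with ThreadPoolExecutor() as ex:
    partial_counts = list(ex.map(count_pairs, words))

# Reduce step: merge the independent local maps into one global count.
total = Counter()
for counts in partial_counts:
    total.update(counts)
print(total.most_common(3))
```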
4 changes: 2 additions & 2 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,8 +1,8 @@
+use crate::parallelism::*;
use crate::tokenizer::{
    Decoder, Encoding, NormalizedString, Offsets, PostProcessor, PreTokenizer, Result,
};
use onig::Regex;
-use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

@@ -97,7 +97,7 @@ impl PreTokenizer for ByteLevel {
            .collect::<Vec<_>>();

        let splits = positions
-            .into_par_iter()
+            .into_maybe_par_iter()
            .map(|range| {
                // Process one of the splits
                let slice = &normalized.get()[range];