diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 73ffc7eef..8b9a0a0d5 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -67,7 +67,10 @@ jobs: build_and_test: name: Check everything builds & tests - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] steps: - name: Checkout repository uses: actions/checkout@v1 diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index 7cfa5f3f2..78ff53cbe 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -18,6 +18,10 @@ This adds some methods to easily save/load an entire tokenizer (`from_str`, `fro activation of the Tensor Cores, while ensuring padding to a multiple of 8. Use with `enable_padding(pad_to_multiple_of=8)` for example. - [#298]: Ability to get the currently set truncation/padding params +- [#311]: Ability to enable/disable the parallelism using the `TOKENIZERS_PARALLELISM` environment +variable. This is especially useful when using `multiprocessing` capabilities, with the `fork` +start method, which happens to be the default on Linux systems. Without disabling the parallelism, +the process deadlocks while encoding. (Cf [#187] for more information) ### Changed - Improved errors generated during truncation: When the provided max length is too low are @@ -190,6 +194,7 @@ delimiter (Works like `.split(delimiter)`) - Fix a bug with the IDs associated with added tokens. - Fix a bug that was causing crashes in Python 3.5 +[#311]: https://github.com/huggingface/tokenizers/pull/311 [#309]: https://github.com/huggingface/tokenizers/pull/309 [#289]: https://github.com/huggingface/tokenizers/pull/289 [#286]: https://github.com/huggingface/tokenizers/pull/286 @@ -207,6 +212,7 @@ delimiter (Works like `.split(delimiter)`) [#193]: https://github.com/huggingface/tokenizers/pull/193 [#190]: https://github.com/huggingface/tokenizers/pull/190 [#188]: https://github.com/huggingface/tokenizers/pull/188 +[#187]: https://github.com/huggingface/tokenizers/issues/187 [#175]: https://github.com/huggingface/tokenizers/issues/175 [#174]: https://github.com/huggingface/tokenizers/issues/174 [#165]: https://github.com/huggingface/tokenizers/pull/165 diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock index a5b4b95f3..2b9122302 100644 --- a/bindings/python/Cargo.lock +++ b/bindings/python/Cargo.lock @@ -244,6 +244,14 @@ dependencies = [ "syn 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "itertools" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "itoa" version = "0.4.5" @@ -486,6 +494,16 @@ dependencies = [ "rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "rayon-cond" +version = "0.1.0" +source = "git+https://github.com/n1t0/rayon-cond#c56e4f1ded0fcb92eac70e0533703bba3ca2983f" +dependencies = [ + "either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", + "itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rayon-core" version = "1.7.0" @@ -611,6 +629,7 @@ dependencies =
[ "onig 6.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)", "regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.17 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", @@ -624,6 +643,7 @@ dependencies = [ name = "tokenizers-python" version = "0.8.0-rc3" dependencies = [ + "libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)", "pyo3 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", @@ -744,6 +764,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum indoc-impl 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)" = "54554010aa3d17754e484005ea0022f1c93839aabc627c2c55f3d7b47206134c" "checksum inventory 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "82d3f4b90287725c97b17478c60dda0c6324e7c84ee1ed72fb9179d0fdf13956" "checksum inventory-impl 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "9092a4fefc9d503e9287ef137f03180a6e7d1b04c419563171ee14947c5e80ec" +"checksum itertools 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484" "checksum itoa 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "b8b7a7c0c47db5545ed3fef7468ee7bb5b74691498139e4b3f6a20685dc6dd8e" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum libc 0.2.68 (registry+https://github.com/rust-lang/crates.io-index)" = "dea0c0405123bba743ee3f91f49b1c7cfb684eef0da0a50110f758ccf24cdff0" @@ -773,6 +794,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" "checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "db6ce3297f9c85e16621bb8cca38a06779ffc31bb8184e1be4bed2be4678a098" +"checksum rayon-cond 0.1.0 (git+https://github.com/n1t0/rayon-cond)" = "" "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9" "checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index f49d143dc..a7a5796b8 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -13,6 +13,7 @@ rayon = "1.3" typetag = "0.1" serde = "1.0" 
serde_json = "1.0" +libc = "0.2" [dependencies.pyo3] version = "0.9.2" diff --git a/bindings/python/examples/train_bert_wordpiece.py b/bindings/python/examples/train_bert_wordpiece.py index 58560dde2..c31146b8a 100644 --- a/bindings/python/examples/train_bert_wordpiece.py +++ b/bindings/python/examples/train_bert_wordpiece.py @@ -47,4 +47,4 @@ ) # Save the files -tokenizer.save(args.out, args.name) +tokenizer.save_model(args.out, args.name) diff --git a/bindings/python/examples/train_bytelevel_bpe.py b/bindings/python/examples/train_bytelevel_bpe.py index 2a0382d17..bd78710b2 100644 --- a/bindings/python/examples/train_bytelevel_bpe.py +++ b/bindings/python/examples/train_bytelevel_bpe.py @@ -44,7 +44,7 @@ ) # Save the files -tokenizer.save(args.out, args.name) +tokenizer.save_model(args.out, args.name) # Restoring model from learned vocab/merges tokenizer = ByteLevelBPETokenizer( diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index cee23fdd1..991a1e25c 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,3 +1,5 @@ +extern crate tokenizers as tk; + mod decoders; mod encoding; mod error; @@ -13,6 +15,23 @@ mod utils; use pyo3::prelude::*; use pyo3::wrap_pymodule; +// For users using multiprocessing in python, it is quite easy to fork the process running +// tokenizers, ending up with a deadlock because we internaly make use of multithreading. So +// we register a callback to be called in the event of a fork so that we can warn the user. +static mut REGISTERED_FORK_CALLBACK: bool = false; +extern "C" fn child_after_fork() { + if !tk::parallelism::is_parallelism_configured() { + println!( + "The current process just got forked. Disabling parallelism to avoid deadlocks..." + ); + println!( + "To disable this warning, please explicitly set {}=(true | false)", + tk::parallelism::ENV_VARIABLE + ); + tk::parallelism::set_parallelism(false); + } +} + /// Trainers Module #[pymodule] fn trainers(_py: Python, m: &PyModule) -> PyResult<()> { @@ -84,6 +103,15 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> { /// Tokenizers Module #[pymodule] fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> { + // Register the fork callback + #[cfg(target_family = "unix")] + unsafe { + if !REGISTERED_FORK_CALLBACK { + libc::pthread_atfork(None, None, Some(child_after_fork)); + REGISTERED_FORK_CALLBACK = true; + } + } + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 677e66085..8214c4413 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -6,8 +6,8 @@ use super::utils::Container; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; -use rayon::prelude::*; use std::path::Path; +use tk::parallelism::*; #[pyclass] struct EncodeInput { @@ -154,7 +154,7 @@ impl Model { fn encode_batch(&self, sequences: Vec, type_id: u32) -> PyResult> { ToPyResult(self.model.execute(|model| { sequences - .into_par_iter() + .into_maybe_par_iter() .map(|sequence| { let sequence = sequence.into_input(); if sequence.is_empty() { diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 26be32cc0..b320f8354 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -1,6 +1,6 @@ import pickle import pytest -from ..utils import data_dir, roberta_files, bert_files +from ..utils import data_dir, roberta_files, bert_files, 
multiprocessing_with_parallelism from tokenizers import AddedToken, Tokenizer, Encoding from tokenizers.models import Model, BPE, WordPiece @@ -289,3 +289,8 @@ def test_post_process(self): # Can post process a pair of encodings output = tokenizer.post_process(encoding, pair_encoding) assert output.tokens == ["my", "pair", "[PAD]", "[PAD]"] + + def test_multiprocessing_with_parallelism(self): + tokenizer = Tokenizer(BPE()) + multiprocessing_with_parallelism(tokenizer, False) + multiprocessing_with_parallelism(tokenizer, True) diff --git a/bindings/python/tests/implementations/test_bert_wordpiece.py b/bindings/python/tests/implementations/test_bert_wordpiece.py index fb241cd10..9c0ffc40e 100644 --- a/bindings/python/tests/implementations/test_bert_wordpiece.py +++ b/bindings/python/tests/implementations/test_bert_wordpiece.py @@ -1,4 +1,4 @@ -from ..utils import data_dir, bert_files +from ..utils import data_dir, bert_files, multiprocessing_with_parallelism from tokenizers import BertWordPieceTokenizer @@ -19,3 +19,8 @@ def test_basic_encode(self, bert_files): assert output.tokens == ["my", "name", "is", "john", "pair"] assert output.offsets == [(0, 2), (3, 7), (8, 10), (11, 15), (0, 4)] assert output.type_ids == [0, 0, 0, 0, 1] + + def test_multiprocessing_with_parallelism(self, bert_files): + tokenizer = BertWordPieceTokenizer(bert_files["vocab"]) + multiprocessing_with_parallelism(tokenizer, False) + multiprocessing_with_parallelism(tokenizer, True) diff --git a/bindings/python/tests/implementations/test_byte_level_bpe.py b/bindings/python/tests/implementations/test_byte_level_bpe.py index d5a4673e0..27a2a8554 100644 --- a/bindings/python/tests/implementations/test_byte_level_bpe.py +++ b/bindings/python/tests/implementations/test_byte_level_bpe.py @@ -1,4 +1,4 @@ -from ..utils import data_dir, roberta_files +from ..utils import data_dir, roberta_files, multiprocessing_with_parallelism from tokenizers import ByteLevelBPETokenizer @@ -79,3 +79,8 @@ def test_lowerspace(self, roberta_files): "Ġlazy", "Ġdog", ] + + def test_multiprocessing_with_parallelism(self, roberta_files): + tokenizer = ByteLevelBPETokenizer(roberta_files["vocab"], roberta_files["merges"]) + multiprocessing_with_parallelism(tokenizer, False) + multiprocessing_with_parallelism(tokenizer, True) diff --git a/bindings/python/tests/implementations/test_char_bpe.py b/bindings/python/tests/implementations/test_char_bpe.py index 66b45f43a..6867e2cc8 100644 --- a/bindings/python/tests/implementations/test_char_bpe.py +++ b/bindings/python/tests/implementations/test_char_bpe.py @@ -1,4 +1,4 @@ -from ..utils import data_dir, openai_files +from ..utils import data_dir, openai_files, multiprocessing_with_parallelism from tokenizers import CharBPETokenizer @@ -42,3 +42,8 @@ def test_decoding(self, openai_files): tokenizer = CharBPETokenizer(openai_files["vocab"], openai_files["merges"], lowercase=True) decoded = tokenizer.decode(tokenizer.encode("my name is john").ids) assert decoded == "my name is john" + + def test_multiprocessing_with_parallelism(self, openai_files): + tokenizer = CharBPETokenizer(openai_files["vocab"], openai_files["merges"]) + multiprocessing_with_parallelism(tokenizer, False) + multiprocessing_with_parallelism(tokenizer, True) diff --git a/bindings/python/tests/utils.py b/bindings/python/tests/utils.py index cdf52fcb2..9b89287ba 100644 --- a/bindings/python/tests/utils.py +++ b/bindings/python/tests/utils.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import requests import pytest @@ -56,3 +57,33 @@ 
def openai_files(data_dir): "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt" ), } + + +def multiprocessing_with_parallelism(tokenizer, enabled: bool): + """ + This helper can be used to test that disabling parallelism avoids deadlocks when the + same tokenizer is used after forking. + """ + # It's essential to this test that we call 'encode' or 'encode_batch' + # before the fork. This causes the main process to "lock" some resources + # provided by the Rust "rayon" crate that are needed for parallel processing. + tokenizer.encode("Hi") + tokenizer.encode_batch(["hi", "there"]) + + def encode(tokenizer): + tokenizer.encode("Hi") + tokenizer.encode_batch(["hi", "there"]) + + # Make sure this environment variable is set before the fork happens + os.environ["TOKENIZERS_PARALLELISM"] = str(enabled) + p = mp.Process(target=encode, args=(tokenizer,)) + p.start() + p.join(timeout=1) + + # At this point the process may or may not have exited successfully, depending on whether + # parallelism was activated. So we check its status and kill it if needed + alive = p.is_alive() + if alive: + p.terminate() + + assert (alive and mp.get_start_method() == "fork") == enabled diff --git a/tokenizers/CHANGELOG.md b/tokenizers/CHANGELOG.md index 834b08063..5d8e2de90 100644 --- a/tokenizers/CHANGELOG.md +++ b/tokenizers/CHANGELOG.md @@ -43,6 +43,8 @@ using serde. It is now easy to save/load an entire tokenizer. - [#289]: Ability to pad to a multiple of a specified value. This is especially useful to ensure activation of the Tensor Cores, while ensuring padding to a multiple of 8. - [#298]: Ability to get the currently set truncation/padding params +- [#311]: Ability to enable/disable the parallelism using the `TOKENIZERS_PARALLELISM` environment +variable. ### How to migrate - Replace any `XXX_to_YYY_offsets()` method call by any of the new ones.
@@ -117,6 +119,7 @@ advised, but that's not the question) split up in multiple bytes - [#174]: The `LongestFirst` truncation strategy had a bug +[#311]: https://github.com/huggingface/tokenizers/pull/311 [#309]: https://github.com/huggingface/tokenizers/pull/309 [#298]: https://github.com/huggingface/tokenizers/pull/298 [#289]: https://github.com/huggingface/tokenizers/pull/289 diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 2d11ac9b2..e67f6d2cd 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -36,6 +36,7 @@ onig = { version = "6.0", default-features = false } regex = "1.3" regex-syntax = "0.6" rayon = "1.3" +rayon-cond = { version = "*", git = "https://github.com/n1t0/rayon-cond" } serde = { version = "1.0", features = [ "derive" ] } serde_json = "1.0" typetag = "0.1" diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index 437e61dbd..74e3eb2df 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -57,3 +57,6 @@ pub mod utils; // Re-export from tokenizer pub use tokenizer::*; + +// Re-export also parallelism utils +pub use utils::parallelism; diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 44e24523c..4c17d42c1 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -1,9 +1,9 @@ #![allow(clippy::map_entry)] use super::{Pair, WithFirstLastIterator, Word, BPE}; +use crate::parallelism::*; use crate::tokenizer::{AddedToken, Model, Result, Trainer}; use indicatif::{ProgressBar, ProgressStyle}; -use rayon::prelude::*; use std::cmp::Ordering; use std::collections::{BinaryHeap, HashMap, HashSet}; @@ -352,7 +352,7 @@ impl BpeTrainer { p: &Option, ) -> (HashMap, HashMap>) { words - .par_iter() + .maybe_par_iter() .enumerate() .map(|(i, word)| { let mut pair_counts = HashMap::new(); @@ -379,10 +379,10 @@ impl BpeTrainer { h }); *pair_counts.get_mut(&cur_pair).unwrap() += count as i32; + } - if let Some(p) = &p { - p.inc(1); - } + if let Some(p) = &p { + p.inc(1); } (pair_counts, where_to_update) @@ -499,7 +499,7 @@ impl BpeTrainer { // Merge the new pair in every words let changes = top .pos - .par_iter() + .maybe_par_iter() .flat_map(|i| { let w = &words[*i] as *const _ as *mut _; // We can merge each of these words in parallel here because each position diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index c082ac9b3..096900ff1 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,8 +1,8 @@ +use crate::parallelism::*; use crate::tokenizer::{ Decoder, Encoding, NormalizedString, Offsets, PostProcessor, PreTokenizer, Result, }; use onig::Regex; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; @@ -97,7 +97,7 @@ impl PreTokenizer for ByteLevel { .collect::>(); let splits = positions - .into_par_iter() + .into_maybe_par_iter() .map(|range| { // Process one of the splits let slice = &normalized.get()[range]; diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 390d864bb..d280a16be 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -1,6 +1,6 @@ +use crate::parallelism::*; use crate::tokenizer::{Offsets, Token}; use crate::utils::padding::PaddingDirection; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; /// Represents the output of a `Tokenizer`. 
@@ -362,7 +362,7 @@ impl Encoding { direction: PaddingDirection, ) { // Dispatch call to all the overflowings first - self.overflowing.par_iter_mut().for_each(|encoding| { + self.overflowing.maybe_par_iter_mut().for_each(|encoding| { encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction) }); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index be98a8219..8c167aadf 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -14,7 +14,6 @@ use crate::utils::iter::ResultShunt; pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy}; pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy}; use indicatif::{ProgressBar, ProgressStyle}; -use rayon::prelude::*; use std::{ collections::HashMap, fs::File, @@ -36,6 +35,8 @@ pub type Error = Box; pub type Result = std::result::Result; pub type Offsets = (usize, usize); +use crate::utils::parallelism::*; + #[typetag::serde(tag = "type")] /// Takes care of pre-processing strings. pub trait Normalizer: Send + Sync { @@ -532,7 +533,7 @@ impl Tokenizer { add_special_tokens: bool, ) -> Result> { let mut encodings = inputs - .into_par_iter() + .into_maybe_par_iter() .map(|input| self.encode(input, add_special_tokens)) .collect::>>()?; @@ -574,7 +575,7 @@ impl Tokenizer { skip_special_tokens: bool, ) -> Result> { sentences - .into_par_iter() + .into_maybe_par_iter() .map(|sentence| self.decode(sentence, skip_special_tokens)) .collect() } @@ -612,9 +613,8 @@ impl Tokenizer { // We read new lines using this API instead of the Lines Iterator // on purpose. We want to keep the `\n` and potential `\r` between each lines // We use an iterator to be able to chain with par_bridge. - let words = file - .lines_with_ending() - .par_bridge() + file.lines_with_ending() + .maybe_par_bridge() .map_with( &progress, |progress, line| -> Result> { @@ -635,14 +635,16 @@ impl Tokenizer { Ok(words) }, ) - .try_reduce(HashMap::new, |mut acc, ws| { - for (k, v) in ws { - acc.entry(k).and_modify(|c| *c += v).or_insert(v); - } - Ok(acc) - })?; - - Ok(words) + .reduce( + || Ok(HashMap::new()), + |acc, ws| { + let mut acc = acc?; + for (k, v) in ws? { + acc.entry(k).and_modify(|c| *c += v).or_insert(v); + } + Ok(acc) + }, + ) }) .try_fold( HashMap::new(), diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs index f961fd7da..e641102e5 100644 --- a/tokenizers/src/utils/mod.rs +++ b/tokenizers/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod iter; pub mod padding; +pub mod parallelism; pub mod truncation; diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs index 9d03df10c..07cbf1fcf 100644 --- a/tokenizers/src/utils/padding.rs +++ b/tokenizers/src/utils/padding.rs @@ -1,5 +1,5 @@ +use crate::parallelism::*; use crate::tokenizer::{Encoding, Result}; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; /// The various possible padding directions. 
@@ -55,7 +55,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu let mut pad_length = match params.strategy { PaddingStrategy::Fixed(size) => size, PaddingStrategy::BatchLongest => encodings - .par_iter() + .maybe_par_iter() .map(|e| e.get_ids().len()) .max() .unwrap(), @@ -67,7 +67,7 @@ pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Resu } } - encodings.par_iter_mut().for_each(|encoding| { + encodings.maybe_par_iter_mut().for_each(|encoding| { encoding.pad( pad_length, params.pad_id, diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs new file mode 100644 index 000000000..fe6d95843 --- /dev/null +++ b/tokenizers/src/utils/parallelism.rs @@ -0,0 +1,179 @@ +//! +//! This module defines helpers to allow optional Rayon usage. +//! + +use rayon::iter::IterBridge; +use rayon::prelude::*; +use rayon_cond::CondIterator; + +pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; + +/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set +pub fn is_parallelism_configured() -> bool { + std::env::var(ENV_VARIABLE).is_ok() +} + +/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable +pub fn get_parallelism() -> bool { + match std::env::var(ENV_VARIABLE) { + Ok(mut v) => { + v.make_ascii_lowercase(); + match v.as_ref() { + "" | "off" | "false" | "f" | "no" | "n" | "0" => false, + _ => true, + } + } + Err(_) => true, // If we couldn't get the variable, we use the default + } +} + +/// Set the value for `TOKENIZERS_PARALLELISM` for the current process +pub fn set_parallelism(val: bool) { + std::env::set_var(ENV_VARIABLE, if val { "true" } else { "false" }) +} + +/// Allows converting into an iterator that can be executed either in parallel or serially. +/// +/// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable. +/// This variable can have one of the following values +/// - False => "" (empty value), "false", "f", "off", "no", "n", "0" +/// - True => Any other value +/// +pub trait MaybeParallelIterator +where + P: ParallelIterator, + S: Iterator, +{ + /// Convert ourselves into a CondIterator that will be executed either in parallel or serially, + /// based solely on the `TOKENIZERS_PARALLELISM` environment variable + fn into_maybe_par_iter(self) -> CondIterator; + /// Convert ourselves into a CondIterator that will be executed either in parallel or serially, + /// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool. + /// Both must be true to run with parallelism activated.
+ fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator; +} + +impl MaybeParallelIterator for I +where + I: IntoParallelIterator + IntoIterator, + P: ParallelIterator, + S: Iterator, +{ + fn into_maybe_par_iter(self) -> CondIterator { + CondIterator::new(self, get_parallelism()) + } + + fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator { + if cond { + self.into_maybe_par_iter() + } else { + CondIterator::from_serial(self) + } + } +} + +/// Shared reference version of MaybeParallelIterator, works the same but returns an iterator +/// over references, does not consume self +pub trait MaybeParallelRefIterator<'data, P, S> +where + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter(&'data self) -> CondIterator; + fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator; +} + +impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I +where + &'data I: MaybeParallelIterator, + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter(&'data self) -> CondIterator { + self.into_maybe_par_iter() + } + + fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator { + self.into_maybe_par_iter_cond(cond) + } +} + +/// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator +/// over mutable references, does not consume self +pub trait MaybeParallelRefMutIterator<'data, P, S> +where + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter_mut(&'data mut self) -> CondIterator; + fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator; +} + +impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I +where + &'data mut I: MaybeParallelIterator, + P: ParallelIterator, + S: Iterator, + P::Item: 'data, +{ + fn maybe_par_iter_mut(&'data mut self) -> CondIterator { + self.into_maybe_par_iter() + } + + fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator { + self.into_maybe_par_iter_cond(cond) + } +} + +/// Converts any serial iterator into a CondIterator that can run either in parallel or serially. +pub trait MaybeParallelBridge +where + S: Iterator + Send, + T: Send, +{ + fn maybe_par_bridge(self) -> CondIterator, S>; + fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator, S>; +} + +impl MaybeParallelBridge for S +where + S: Iterator + Send, + T: Send, +{ + fn maybe_par_bridge(self) -> CondIterator, S> { + let iter = CondIterator::from_serial(self); + + if get_parallelism() { + CondIterator::from_parallel(iter.into_parallel().right().unwrap()) + } else { + iter + } + } + + fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator, S> { + if cond { + self.maybe_par_bridge() + } else { + CondIterator::from_serial(self) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[ignore] + fn test_maybe_parallel_iterator() { + let mut v = vec![1u32, 2, 3, 4, 5, 6]; + + assert_eq!(v.maybe_par_iter().sum::(), 21); + assert_eq!(v.maybe_par_iter_mut().map(|v| *v * 2).sum::(), 42); + assert_eq!(v.maybe_par_iter().sum::(), 42); + assert_eq!(v.into_maybe_par_iter().sum::(), 42); + }
+}
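As a usage illustration (not part of the diff above), the sketch below mirrors the `multiprocessing_with_parallelism` helper added in `bindings/python/tests/utils.py`: the parent process opts out of the Rust-side thread pool through `TOKENIZERS_PARALLELISM` before forking, so that a child created with the `fork` start method (the Linux default) can call `encode`/`encode_batch` without hitting the rayon deadlock described in [#187]. The `encode_in_child` function name, the warm-up strings, and the 5-second timeout are illustrative assumptions, not API requirements.

```python
import multiprocessing as mp
import os

from tokenizers import Tokenizer
from tokenizers.models import BPE

# Disable the Rust-side thread pool before any fork happens. With the "fork"
# start method, a child that inherits an already-initialized rayon pool can
# deadlock inside encode/encode_batch (see issue #187).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

tokenizer = Tokenizer(BPE())
tokenizer.encode_batch(["warm", "up"])  # parent uses the tokenizer before forking


def encode_in_child(tok):
    # Runs inside the forked child; safe because parallelism is disabled.
    tok.encode("Hi")
    tok.encode_batch(["hi", "there"])


if __name__ == "__main__":
    p = mp.Process(target=encode_in_child, args=(tokenizer,))
    p.start()
    p.join(timeout=5)
    assert not p.is_alive(), "child should not deadlock when parallelism is disabled"
```

Setting the variable to any truthy value (or leaving it unset) keeps the parallel behavior; the fork callback registered in `bindings/python/src/lib.rs` only disables it automatically, with a warning, when the variable was never configured explicitly.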