Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ByteFallback support for tokenizers. #1183

Merged
merged 11 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions bindings/node/lib/bindings/decoders.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ export function byteLevelDecoder(): Decoder;
*/
export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;

/**
* Instantiate a new ByteFallback Decoder
* ByteFallback is a simple trick which converts tokens looking like `<0x61>`
* to pure bytes, and attempts to make them into a string. If the tokens
* cannot be decoded you will get � instead for each inconvertible byte token
*/
export function byteFallbackDecoder(): Decoder;

/**
* Instantiate a new Metaspace
*
Expand Down
1 change: 1 addition & 0 deletions bindings/node/lib/bindings/decoders.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const native = require("./native");
module.exports = {
byteLevelDecoder: native.decoders_ByteLevel,
wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback,
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,
Expand Down
22 changes: 22 additions & 0 deletions bindings/node/lib/bindings/decoders.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
bpeDecoder,
byteFallbackDecoder,
ctcDecoder,
metaspaceDecoder,
sequenceDecoder,
Expand All @@ -22,6 +23,27 @@ describe("wordPieceDecoder", () => {
});
});

describe("byteFallbackDecoder", () => {
  // The factory is called with no arguments here; the test title is kept
  // unchanged so reports stay comparable with the sibling decoder suites.
  it("accepts `undefined` as first parameter", () => {
    expect(byteFallbackDecoder()).toBeDefined();
  });

  it("can decode arrays of strings", () => {
    // Tokens that are not of the `<0x..>` form are concatenated unchanged.
    expect(byteFallbackDecoder().decode(["Hel", "lo"])).toEqual("Hello");
    expect(byteFallbackDecoder().decode(["My", " na", "me"])).toEqual("My name");

    // A byte token that is valid UTF-8 on its own decodes to that character.
    // (This assertion used to be repeated three times; every call builds a
    // fresh decoder, so the repeats added no coverage.)
    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");

    // Incomplete multi-byte sequences produce one replacement character per
    // byte token.
    expect(byteFallbackDecoder().decode(["<0xE5>"])).toEqual("�");
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>"])).toEqual("��");
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "a"])).toEqual("��a");

    // A complete three-byte sequence is assembled into a single character,
    // with or without a trailing plain token.
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>"])).toEqual("叫");
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>", "a"])).toEqual(
      "叫a"
    );
  });
});

describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined();
Expand Down
11 changes: 11 additions & 0 deletions bindings/node/native/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}

/// byte_fallback()
///
/// Creates a `JsDecoder` and stores a freshly constructed `ByteFallback`
/// decoder inside it. No arguments are read from the call context.
fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
    let mut js_decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
    let guard = cx.lock();
    // Build the wrapped decoder first, then install it under the lock guard.
    let inner = Arc::new(tk::decoders::byte_fallback::ByteFallback::new().into());
    js_decoder.borrow_mut(&guard).decoder = Some(inner);
    Ok(js_decoder)
}

/// metaspace(replacement: String = "_", add_prefix_space: bool = true)
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
Expand Down Expand Up @@ -147,6 +157,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
Expand Down
4 changes: 4 additions & 0 deletions bindings/node/native/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ struct BpeOptions {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
fuse_unk: Option<bool>,
byte_fallback: Option<bool>,
}
impl BpeOptions {
fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder {
Expand All @@ -153,6 +154,9 @@ impl BpeOptions {
if let Some(fuse_unk) = self.fuse_unk {
builder = builder.fuse_unk(fuse_unk);
}
if let Some(byte_fallback) = self.byte_fallback {
builder = builder.byte_fallback(byte_fallback);
}

builder
}
Expand Down
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC
Expand Down
24 changes: 24 additions & 0 deletions bindings/python/py_src/tokenizers/decoders/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ class BPEDecoder(Decoder):
"""
pass

class ByteFallback(Decoder):
"""
ByteFallback Decoder
ByteFallback is a simple trick which converts tokens looking like `<0x61>`
to pure bytes, and attempts to make them into a string. If the tokens
cannot be decoded you will get � instead for each inconvertible byte token

"""

def __init__(self):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string

Args:
tokens (:obj:`List[str]`):
The list of tokens to decode

Returns:
:obj:`str`: The decoded string
"""
pass

class ByteLevel(Decoder):
"""
ByteLevel Decoder
Expand Down
4 changes: 4 additions & 0 deletions bindings/python/py_src/tokenizers/models/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class BPE(Model):

fuse_unk (:obj:`bool`, `optional`):
Whether to fuse any subsequent unknown tokens into a single one

byte_fallback (:obj:`bool`, `optional`):
Whether to use spm byte-fallback trick (defaults to False)
"""

def __init__(
Expand All @@ -118,6 +121,7 @@ class BPE(Model):
continuing_subword_prefix=None,
end_of_word_suffix=None,
fuse_unk=None,
byte_fallback=False,
):
pass
@staticmethod
Expand Down
22 changes: 22 additions & 0 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
Expand Down Expand Up @@ -41,6 +42,9 @@ impl PyDecoder {
PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
Expand Down Expand Up @@ -196,6 +200,23 @@ impl PyWordPieceDec {
}
}

/// ByteFallback Decoder
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded you will get � instead for each inconvertible byte token
///
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteFallback")]
#[pyo3(text_signature = "(self)")]
pub struct PyByteFallbackDec {}
#[pymethods]
impl PyByteFallbackDec {
    // Python-side constructor: takes no arguments and returns the subclass
    // marker together with a base `PyDecoder` converted from
    // `ByteFallback::new()`.
    #[new]
    #[pyo3(signature = ())]
    fn new() -> (Self, PyDecoder) {
        let base: PyDecoder = ByteFallback::new().into();
        (Self {}, base)
    }
}

/// Metaspace Decoder
///
/// Args:
Expand Down Expand Up @@ -453,6 +474,7 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;
Expand Down
16 changes: 15 additions & 1 deletion bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,12 @@ impl PyModel {
///
/// fuse_unk (:obj:`bool`, `optional`):
/// Whether to fuse any subsequent unknown tokens into a single one
///
/// byte_fallback (:obj:`bool`, `optional`):
/// Whether to use spm byte-fallback trick (defaults to False)
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
#[pyo3(
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)"
)]
pub struct PyBPE {}

Expand All @@ -277,6 +280,7 @@ impl PyBPE {
}
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
Expand Down Expand Up @@ -385,6 +389,16 @@ impl PyBPE {
setter!(self_, BPE, fuse_unk, fuse_unk);
}

// Property getter exposed through `#[getter]`; yields a `bool`.
// Delegates to the `getter!` macro with the `BPE` variant and the
// `byte_fallback` field name.
// NOTE(review): `getter!` is defined outside this hunk — presumably it reads
// the field from the wrapped BPE model; confirm against the macro.
#[getter]
fn get_byte_fallback(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, byte_fallback)
}

// Property setter exposed through `#[setter]`; accepts the new `bool` value.
// Delegates to the `setter!` macro with the `BPE` variant, the
// `byte_fallback` field name and the supplied value.
// NOTE(review): `setter!` is defined outside this hunk — presumably it writes
// the field on the wrapped BPE model; confirm against the macro.
#[setter]
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback);
}

#[new]
#[pyo3(signature = (vocab=None, merges=None, **kwargs))]
fn new(
Expand Down
20 changes: 19 additions & 1 deletion bindings/python/tests/bindings/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback


class TestByteLevel:
Expand Down Expand Up @@ -54,6 +54,24 @@ def test_can_modify(self):
assert decoder.cleanup == True


class TestByteFallback:
    """Binding-level checks for the ``ByteFallback`` decoder."""

    def test_instantiate(self):
        # Construction succeeds and the object is both a Decoder and a ByteFallback.
        assert ByteFallback() is not None
        for expected_type in (Decoder, ByteFallback):
            assert isinstance(ByteFallback(), expected_type)
        # A pickle round-trip must give back a ByteFallback instance.
        round_tripped = pickle.loads(pickle.dumps(ByteFallback()))
        assert isinstance(round_tripped, ByteFallback)

    def test_decoding(self):
        decoder = ByteFallback()
        # (tokens, expected string) pairs, checked in order on one decoder.
        cases = [
            (["My", " na", "me"], "My name"),
            (["<0x61>"], "a"),
            (["<0xE5>"], "�"),
            (["<0xE5>", "<0x8f>"], "��"),
            (["<0xE5>", "<0x8f>", "<0xab>"], "叫"),
            (["<0xE5>", "<0x8f>", "a"], "��a"),
            (["<0xE5>", "<0x8f>", "<0xab>", "a"], "叫a"),
        ]
        for tokens, expected in cases:
            assert decoder.decode(tokens) == expected


class TestMetaspace:
def test_instantiate(self):
assert Metaspace() is not None
Expand Down
3 changes: 3 additions & 0 deletions bindings/python/tests/bindings/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_can_modify(self):
assert model.continuing_subword_prefix == "__prefix__"
assert model.end_of_word_suffix == "__suffix__"
assert model.fuse_unk == False
assert model.byte_fallback == False

# Modify these
model.dropout = 0.1
Expand All @@ -66,6 +67,8 @@ def test_can_modify(self):
assert model.end_of_word_suffix == "suff"
model.fuse_unk = True
assert model.fuse_unk == True
model.byte_fallback = True
assert model.byte_fallback == True


class TestWordPiece:
Expand Down
Loading