Skip to content

Commit

Permalink
Tests + Deserialization improvement for normalizers. (#1604)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Aug 8, 2024
1 parent 49dafd7 commit 56c9c70
Showing 1 changed file with 178 additions and 3 deletions.
181 changes: 178 additions & 3 deletions tokenizers/src/normalizers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ pub use crate::normalizers::replace::Replace;
pub use crate::normalizers::strip::{Strip, StripAccents};
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
pub use crate::normalizers::utils::{Lowercase, Sequence};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};

use crate::{NormalizedString, Normalizer};

/// Wrapper for known Normalizers.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug, Serialize)]
#[serde(untagged)]
pub enum NormalizerWrapper {
BertNormalizer(BertNormalizer),
Expand All @@ -38,6 +38,149 @@ pub enum NormalizerWrapper {
ByteLevel(ByteLevel),
}

impl<'de> Deserialize<'de> for NormalizerWrapper {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
pub struct Tagged {
#[serde(rename = "type")]
variant: EnumType,
#[serde(flatten)]
rest: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub enum EnumType {
Bert,
Strip,
StripAccents,
NFC,
NFD,
NFKC,
NFKD,
Sequence,
Lowercase,
Nmt,
Precompiled,
Replace,
Prepend,
ByteLevel,
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum NormalizerHelper {
Tagged(Tagged),
Legacy(serde_json::Value),
}

#[derive(Deserialize)]
#[serde(untagged)]
pub enum NormalizerUntagged {
BertNormalizer(BertNormalizer),
StripNormalizer(Strip),
StripAccents(StripAccents),
NFC(NFC),
NFD(NFD),
NFKC(NFKC),
NFKD(NFKD),
Sequence(Sequence),
Lowercase(Lowercase),
Nmt(Nmt),
Precompiled(Precompiled),
Replace(Replace),
Prepend(Prepend),
ByteLevel(ByteLevel),
}

let helper = NormalizerHelper::deserialize(deserializer)?;
Ok(match helper {
NormalizerHelper::Tagged(model) => {
let mut values: serde_json::Map<String, serde_json::Value> =
serde_json::from_value(model.rest).expect("Parsed values");
values.insert(
"type".to_string(),
serde_json::to_value(&model.variant).expect("Reinsert"),
);
let values = serde_json::Value::Object(values);
match model.variant {
EnumType::Bert => NormalizerWrapper::BertNormalizer(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Strip => NormalizerWrapper::StripNormalizer(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::StripAccents => NormalizerWrapper::StripAccents(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFC => NormalizerWrapper::NFC(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFD => NormalizerWrapper::NFD(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFKC => NormalizerWrapper::NFKC(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::NFKD => NormalizerWrapper::NFKD(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Sequence => NormalizerWrapper::Sequence(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Lowercase => NormalizerWrapper::Lowercase(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Nmt => NormalizerWrapper::Nmt(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Precompiled => NormalizerWrapper::Precompiled(
serde_json::from_str(
&serde_json::to_string(&values).expect("Can reserialize precompiled"),
)
// .map_err(serde::de::Error::custom)
.expect("Precompiled"),
),
EnumType::Replace => NormalizerWrapper::Replace(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::Prepend => NormalizerWrapper::Prepend(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
),
}
}

NormalizerHelper::Legacy(value) => {
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
match untagged {
NormalizerUntagged::BertNormalizer(bpe) => {
NormalizerWrapper::BertNormalizer(bpe)
}
NormalizerUntagged::StripNormalizer(bpe) => {
NormalizerWrapper::StripNormalizer(bpe)
}
NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
}
}
})
}
}

impl Normalizer for NormalizerWrapper {
fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
match self {
Expand Down Expand Up @@ -91,7 +234,7 @@ mod tests {
match reconstructed {
Err(err) => assert_eq!(
err.to_string(),
"data did not match any variant of untagged enum NormalizerWrapper"
"data did not match any variant of untagged enum NormalizerUntagged"
),
_ => panic!("Expected an error here"),
}
Expand All @@ -103,4 +246,36 @@ mod tests {
NormalizerWrapper::Prepend(_)
));
}

#[test]
fn normalizer_serialization() {
    /// Parses `json` as a `NormalizerWrapper` and returns the rendered error,
    /// panicking if the input is unexpectedly accepted.
    fn rejection_message(json: &str) -> String {
        match serde_json::from_str::<NormalizerWrapper>(json) {
            Err(err) => format!("{err}"),
            _ => panic!("Expected error"),
        }
    }

    // Message produced when a payload falls through to the untagged fallback
    // and matches none of its variants.
    const UNTAGGED_MISMATCH: &str =
        "data did not match any variant of untagged enum NormalizerUntagged";

    // A tagged sequence with no children is accepted.
    let empty_sequence =
        serde_json::from_str::<NormalizerWrapper>(r#"{"type":"Sequence","normalizers":[]}"#);
    assert!(empty_sequence.is_ok());

    // A child without a `type` tag goes through the untagged fallback and
    // matches nothing.
    assert_eq!(
        rejection_message(r#"{"type":"Sequence","normalizers":[{}]}"#),
        UNTAGGED_MISMATCH
    );

    // A top-level object without a `type` tag is rejected the same way.
    assert_eq!(
        rejection_message(r#"{"replacement":"▁","prepend_scheme":"always"}"#),
        UNTAGGED_MISMATCH
    );

    // With a recognized `type`, the error names the exact problem.
    assert_eq!(
        rejection_message(r#"{"type":"Sequence","prepend_scheme":"always"}"#),
        "missing field `normalizers`"
    );
}
}

0 comments on commit 56c9c70

Please sign in to comment.