From ffce32cf0b035ae3fec17f713341e7db53774796 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 19:37:59 +0900 Subject: [PATCH 1/7] Support `&[u8]` --- src/lib.rs | 145 +++++++++++++++++++++++++++++++------------------ tests/tests.rs | 42 ++++++++++++++ 2 files changed, 133 insertions(+), 54 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 540df57..17d2466 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,66 +43,108 @@ use std::collections::HashMap; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, + parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, PatSlice, PatWild, }; +static ERROR_UNEXPECTED_PATTERN: &str = + "`trie_match` only supports string literals, byte string literals, and u8 slices as patterns"; +static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here"; +static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported"; +static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern"; +static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered"; +static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal"; + use crate::trie::Sparse; +/// Converts a literal pattern into a byte sequence. +fn convert_literal_pattern(pat: &ExprLit) -> Result>, Error> { + let ExprLit { attrs, lit } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match lit { + Lit::Str(s) => Ok(Some(s.value().into())), + Lit::ByteStr(s) => Ok(Some(s.value())), + _ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)), + } +} + +/// Converts a slice pattern into a byte sequence. +fn convert_slice_pattern(pat: &PatSlice) -> Result>, Error> { + let PatSlice { attrs, elems, .. } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + let mut result = vec![]; + for elem in elems { + match elem { + Pat::Lit(ExprLit { attrs, lit }) => { + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match lit { + Lit::Int(i) => { + let int_type = i.suffix(); + if int_type != "u8" && int_type != "" { + return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL)); + } + result.push(i.base10_parse::()?); + } + Lit::Byte(b) => { + result.push(b.value()); + } + _ => { + return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL)); + } + } + } + _ => { + return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL)); + } + } + } + Ok(Some(result)) +} + +fn convert_wildcard_pattern(pat: &PatWild) -> Result>, Error> { + let PatWild { attrs, .. } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + Ok(None) +} + /// Retrieves pattern strings from the given token. /// /// None indicates a wild card pattern (`_`). -fn retrieve_match_patterns(pat: &Pat) -> Result>, Error> { +fn retrieve_match_patterns(pat: &Pat) -> Result>>, Error> { let mut pats = vec![]; match pat { - Pat::Lit(ExprLit { - lit: Lit::Str(s), - attrs, - }) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(Some(s.value())); - } + Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), + Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), + Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), Pat::Or(PatOr { attrs, leading_vert: None, cases, }) => { if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); } for pat in cases { match pat { - Pat::Lit(ExprLit { - lit: Lit::Str(s), - attrs, - }) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(Some(s.value())); - } + Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), + Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), + Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), _ => { - return Err(Error::new( - pat.span(), - "`trie_match` only supports string literal patterns", - )); + return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } } } - Pat::Wild(PatWild { attrs, .. }) => { - if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); - } - pats.push(None); - } _ => { - return Err(Error::new( - pat.span(), - "`trie_match` only supports string literal patterns", - )); + return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } Ok(pats) @@ -110,7 +152,7 @@ fn retrieve_match_patterns(pat: &Pat) -> Result>, Error> { struct MatchInfo { bodies: Vec, - pattern_map: HashMap, + pattern_map: HashMap, usize>, wildcard_idx: usize, } @@ -130,21 +172,21 @@ fn parse_match_arms(arms: Vec) -> Result { ) in arms.into_iter().enumerate() { if let Some(attr) = attrs.first() { - return Err(Error::new(attr.span(), "attribute not supported here")); + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); } if let Some((if_token, _)) = guard { - return Err(Error::new(if_token.span(), "match guard not supported")); + return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED)); } - let pat_strs = retrieve_match_patterns(&pat)?; - for pat_str in pat_strs { - if let Some(pat_str) = pat_str { - if pattern_map.contains_key(&pat_str) { - return Err(Error::new(pat.span(), "unreachable pattern")); + let pat_bytes_set = retrieve_match_patterns(&pat)?; + for pat_bytes in pat_bytes_set { + if let Some(pat_bytes) = pat_bytes { + if pattern_map.contains_key(&pat_bytes) { + return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN)); } - pattern_map.insert(pat_str, i); + pattern_map.insert(pat_bytes, i); } else { if wildcard_idx.is_some() { - return Err(Error::new(pat.span(), "unreachable pattern")); + return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN)); } wildcard_idx.replace(i); } @@ -152,10 +194,7 @@ fn parse_match_arms(arms: Vec) -> Result { bodies.push(*body); } let Some(wildcard_idx) = wildcard_idx else { - return Err(Error::new( - Span::call_site(), - "non-exhaustive patterns: `_` not covered", - )); + return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED)); }; Ok(MatchInfo { bodies, @@ -168,13 +207,11 @@ fn trie_match_inner(input: ExprMatch) -> Result { let ExprMatch { attrs, expr, arms, .. } = input; - let MatchInfo { bodies, pattern_map, wildcard_idx, } = parse_match_arms(arms)?; - let mut trie = Sparse::new(); for (k, v) in pattern_map { if v == wildcard_idx { @@ -203,12 +240,12 @@ fn trie_match_inner(input: ExprMatch) -> Result { #( #enumvalue, )* } #( #attr )* - match (|query: &str| unsafe { + match (|query: &[u8]| unsafe { let bases: &'static [i32] = &[ #( #base, )* ]; let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ]; let mut pos = 0; let mut base = bases[0]; - for &b in query.as_bytes() { + for &b in query { pos = base.wrapping_add(i32::from(b)) as usize; if let Some((_, check)) = out_checks.get(pos) { if *check == b { @@ -219,7 +256,7 @@ fn trie_match_inner(input: ExprMatch) -> Result { return __TrieMatchValue::#wildcard_ident; } out_checks.get_unchecked(pos).0 - })( #expr ) { + })( ::std::convert::AsRef::<[u8]>::as_ref( #expr ) ) { #( #arm, )* } } diff --git a/tests/tests.rs b/tests/tests.rs index 3632f37..21863af 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -125,3 +125,45 @@ fn test_try_base_conflict() { assert_eq!(f("\u{2}\u{3}"), 1); assert_eq!(f("\u{3}"), 1); } + +#[test] +fn test_bytes_literal() { + let f = |text: &[u8]| { + trie_match! { + match text { + b"abc" => 0, + _ => 1, + } + } + }; + assert_eq!(f(b"abc"), 0); + assert_eq!(f(b"ab"), 1); +} + +#[test] +fn test_slice_byte_literal() { + let f = |text: &[u8]| { + trie_match! { + match text { + [b'a', b'b', b'c'] => 0, + _ => 1, + } + } + }; + assert_eq!(f(b"abc"), 0); + assert_eq!(f(b"ab"), 1); +} + +#[test] +fn test_slice_numbers() { + let f = |text: &[u8]| { + trie_match! { + match text { + [0, 1, 2] => 0, + _ => 1, + } + } + }; + assert_eq!(f(&[0, 1, 2]), 0); + assert_eq!(f(&[0, 1]), 1); +} From b63792f5ccbaaa6dddda2a7f6495882f5f9f5d88 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 19:38:45 +0900 Subject: [PATCH 2/7] clippy --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 17d2466..4601211 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,8 +43,8 @@ use std::collections::HashMap; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ - parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, PatSlice, - PatWild, + parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, + PatSlice, PatWild, }; static ERROR_UNEXPECTED_PATTERN: &str = @@ -86,7 +86,7 @@ fn convert_slice_pattern(pat: &PatSlice) -> Result>, Error> { match lit { Lit::Int(i) => { let int_type = i.suffix(); - if int_type != "u8" && int_type != "" { + if int_type != "u8" && !int_type.is_empty() { return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL)); } result.push(i.base10_parse::()?); From 96374b4c6feb73d95f1e776fca924dc7b496ad80 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 19:50:01 +0900 Subject: [PATCH 3/7] Support reference --- src/lib.rs | 20 +++++++++++++++++++- tests/tests.rs | 14 ++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 4601211..c400fcb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,7 +44,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr, - PatSlice, PatWild, + PatReference, PatSlice, PatWild, }; static ERROR_UNEXPECTED_PATTERN: &str = @@ -115,6 +115,21 @@ fn convert_wildcard_pattern(pat: &PatWild) -> Result>, Error> { Ok(None) } +fn convert_reference_pattern(pat: &PatReference) -> Result>, Error> { + let PatReference { attrs, pat, .. } = pat; + if let Some(attr) = attrs.first() { + return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); + } + match &**pat { + Pat::Lit(pat) => Ok(convert_literal_pattern(pat)?), + Pat::Slice(pat) => Ok(convert_slice_pattern(pat)?), + Pat::Wild(pat) => Ok(convert_wildcard_pattern(pat)?), + _ => { + return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); + } + } +} + /// Retrieves pattern strings from the given token. /// /// None indicates a wild card pattern (`_`). @@ -124,6 +139,7 @@ fn retrieve_match_patterns(pat: &Pat) -> Result>>, Error> { Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), + Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?), Pat::Or(PatOr { attrs, leading_vert: None, @@ -137,7 +153,9 @@ fn retrieve_match_patterns(pat: &Pat) -> Result>>, Error> { Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?), Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?), Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), + Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?), _ => { + dbg!(pat); return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } diff --git a/tests/tests.rs b/tests/tests.rs index 21863af..f88f3a7 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -167,3 +167,17 @@ fn test_slice_numbers() { assert_eq!(f(&[0, 1, 2]), 0); assert_eq!(f(&[0, 1]), 1); } + +#[test] +fn test_slice_ref_numbers() { + let f = |text: &[u8]| { + trie_match! { + match text { + &[0, 1, 2] => 0, + _ => 1, + } + } + }; + assert_eq!(f(&[0, 1, 2]), 0); + assert_eq!(f(&[0, 1]), 1); +} From 228a13a88dbc8031bb87b3b03db8b52fcba620fd Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 19:53:54 +0900 Subject: [PATCH 4/7] clippy --- src/lib.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c400fcb..975d2f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,9 +124,7 @@ fn convert_reference_pattern(pat: &PatReference) -> Result>, Erro Pat::Lit(pat) => Ok(convert_literal_pattern(pat)?), Pat::Slice(pat) => Ok(convert_slice_pattern(pat)?), Pat::Wild(pat) => Ok(convert_wildcard_pattern(pat)?), - _ => { - return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); - } + _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)), } } @@ -155,7 +153,6 @@ fn retrieve_match_patterns(pat: &Pat) -> Result>>, Error> { Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?), Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?), _ => { - dbg!(pat); return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)); } } From ad852a32629ad56866199443c94e4fc479f9b491 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 20:01:33 +0900 Subject: [PATCH 5/7] fix README --- README.md | 2 +- src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c7c1b47..d9a8ea4 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ to achieve efficient state-to-state traversal, and the time complexity becomes The followings are different from the normal `match` expression: -* Only supports string comparison. +* Only supports strings, byte strings, and u8 slices as patterns. * The wildcard is evaluated last. (The normal `match` expression does not match patterns after the wildcard.) * Pattern bindings are unavailable. diff --git a/src/lib.rs b/src/lib.rs index 975d2f0..68786e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ //! //! The followings are different from the normal `match` expression: //! -//! * Only supports string comparison. +//! * Only supports strings, byte strings, and u8 slices as patterns. //! * The wildcard is evaluated last. (The normal `match` expression does not //! match patterns after the wildcard.) //! * Pattern bindings are unavailable. From 2198c697a5611d6ed06f3949baec4178556202d1 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Sun, 17 Sep 2023 22:14:25 +0900 Subject: [PATCH 6/7] update doc --- src/lib.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 68786e9..6ff724f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,6 +107,10 @@ fn convert_slice_pattern(pat: &PatSlice) -> Result>, Error> { Ok(Some(result)) } +/// Checks a wildcard pattern and returns `None`. +/// +/// The reason the type is `Result>, Error>` instead of `Result<(), Error>` is for +/// consistency with other functions. fn convert_wildcard_pattern(pat: &PatWild) -> Result>, Error> { let PatWild { attrs, .. } = pat; if let Some(attr) = attrs.first() { @@ -115,6 +119,7 @@ fn convert_wildcard_pattern(pat: &PatWild) -> Result>, Error> { Ok(None) } +/// Converts a reference pattern (e.g. `&[0, 1, ...]`) into a byte sequence. fn convert_reference_pattern(pat: &PatReference) -> Result>, Error> { let PatReference { attrs, pat, .. } = pat; if let Some(attr) = attrs.first() { @@ -123,7 +128,7 @@ fn convert_reference_pattern(pat: &PatReference) -> Result>, Erro match &**pat { Pat::Lit(pat) => Ok(convert_literal_pattern(pat)?), Pat::Slice(pat) => Ok(convert_slice_pattern(pat)?), - Pat::Wild(pat) => Ok(convert_wildcard_pattern(pat)?), + Pat::Reference(pat) => Ok(convert_reference_pattern(pat)?), _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)), } } From 1fe5299b62101faef0c840ebcc4d920ac36d2030 Mon Sep 17 00:00:00 2001 From: Koichi Akabe Date: Mon, 18 Sep 2023 00:24:56 +0900 Subject: [PATCH 7/7] Update lib.rs --- src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6ff724f..2980ef1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,9 +126,9 @@ fn convert_reference_pattern(pat: &PatReference) -> Result>, Erro return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED)); } match &**pat { - Pat::Lit(pat) => Ok(convert_literal_pattern(pat)?), - Pat::Slice(pat) => Ok(convert_slice_pattern(pat)?), - Pat::Reference(pat) => Ok(convert_reference_pattern(pat)?), + Pat::Lit(pat) => convert_literal_pattern(pat), + Pat::Slice(pat) => convert_slice_pattern(pat), + Pat::Reference(pat) => convert_reference_pattern(pat), _ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)), } }