Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support byte string patterns #12

Merged
merged 8 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ to achieve efficient state-to-state traversal, and the time complexity becomes

The following points differ from the normal `match` expression:

* Only supports string comparison.
* Only supports strings, byte strings, and u8 slices as patterns.
* The wildcard is evaluated last. (The normal `match` expression does not
match patterns after the wildcard.)
* Pattern bindings are unavailable.
Expand Down
167 changes: 112 additions & 55 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
//!
//! The following points differ from the normal `match` expression:
//!
//! * Only supports string comparison.
//! * Only supports strings, byte strings, and u8 slices as patterns.
//! * The wildcard is evaluated last. (The normal `match` expression does not
//! match patterns after the wildcard.)
//! * Pattern bindings are unavailable.
Expand All @@ -44,73 +44,135 @@ use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{
parse_macro_input, spanned::Spanned, Arm, Error, Expr, ExprLit, ExprMatch, Lit, Pat, PatOr,
PatWild,
PatReference, PatSlice, PatWild,
};

// Diagnostic messages attached to the `syn::Error`s produced while validating the macro input.

// Reported for a pattern that is not one of the supported pattern kinds.
static ERROR_UNEXPECTED_PATTERN: &str =
"`trie_match` only supports string literals, byte string literals, and u8 slices as patterns";
// Reported when an attribute is attached to a pattern, a pattern element, or a match arm.
static ERROR_ATTRIBUTE_NOT_SUPPORTED: &str = "attribute not supported here";
// Reported when a match arm carries an `if` guard.
static ERROR_GUARD_NOT_SUPPORTED: &str = "match guard not supported";
// Reported when a byte sequence or the wildcard appears more than once.
static ERROR_UNREACHABLE_PATTERN: &str = "unreachable pattern";
// Reported when no arm contains the wildcard pattern (`_`).
static ERROR_PATTERN_NOT_COVERED: &str = "non-exhaustive patterns: `_` not covered";
// Reported when a slice pattern element is not an acceptable `u8` or byte literal.
static ERROR_EXPECTED_U8_LITERAL: &str = "expected `u8` integer literal";

use crate::trie::Sparse;

/// Converts a literal pattern into a byte sequence.
fn convert_literal_pattern(pat: &ExprLit) -> Result<Option<Vec<u8>>, Error> {
let ExprLit { attrs, lit } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match lit {
Lit::Str(s) => Ok(Some(s.value().into())),
Lit::ByteStr(s) => Ok(Some(s.value())),
_ => Err(Error::new(lit.span(), ERROR_UNEXPECTED_PATTERN)),
}
}

/// Converts a slice pattern into a byte sequence.
fn convert_slice_pattern(pat: &PatSlice) -> Result<Option<Vec<u8>>, Error> {
let PatSlice { attrs, elems, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
let mut result = vec![];
for elem in elems {
match elem {
Pat::Lit(ExprLit { attrs, lit }) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match lit {
Lit::Int(i) => {
let int_type = i.suffix();
if int_type != "u8" && !int_type.is_empty() {
return Err(Error::new(i.span(), ERROR_EXPECTED_U8_LITERAL));
}
result.push(i.base10_parse::<u8>()?);
}
Lit::Byte(b) => {
result.push(b.value());
}
_ => {
return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
}
}
}
_ => {
return Err(Error::new(elem.span(), ERROR_EXPECTED_U8_LITERAL));
}
}
}
Ok(Some(result))
}

/// Checks a wildcard pattern and returns `None`.
///
/// The reason the type is `Result<Option<Vec<u8>>, Error>` instead of `Result<(), Error>` is for
/// consistency with other functions.
fn convert_wildcard_pattern(pat: &PatWild) -> Result<Option<Vec<u8>>, Error> {
let PatWild { attrs, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
Ok(None)
}

/// Converts a reference pattern (e.g. `&[0, 1, ...]`) into a byte sequence.
fn convert_reference_pattern(pat: &PatReference) -> Result<Option<Vec<u8>>, Error> {
let PatReference { attrs, pat, .. } = pat;
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
match &**pat {
Pat::Lit(pat) => convert_literal_pattern(pat),
Pat::Slice(pat) => convert_slice_pattern(pat),
Pat::Reference(pat) => convert_reference_pattern(pat),
_ => Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN)),
}
}

/// Retrieves the byte sequences of the patterns contained in the given match arm pattern.
///
/// `None` indicates a wildcard pattern (`_`).
fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<String>>, Error> {
fn retrieve_match_patterns(pat: &Pat) -> Result<Vec<Option<Vec<u8>>>, Error> {
let mut pats = vec![];
match pat {
Pat::Lit(ExprLit {
lit: Lit::Str(s),
attrs,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(Some(s.value()));
}
Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?),
Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?),
Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?),
Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?),
Pat::Or(PatOr {
attrs,
leading_vert: None,
cases,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
for pat in cases {
match pat {
Pat::Lit(ExprLit {
lit: Lit::Str(s),
attrs,
}) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(Some(s.value()));
}
Pat::Lit(pat) => pats.push(convert_literal_pattern(pat)?),
Pat::Slice(pat) => pats.push(convert_slice_pattern(pat)?),
Pat::Wild(pat) => pats.push(convert_wildcard_pattern(pat)?),
Pat::Reference(pat) => pats.push(convert_reference_pattern(pat)?),
_ => {
return Err(Error::new(
pat.span(),
"`trie_match` only supports string literal patterns",
));
return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
}
}
}
}
Pat::Wild(PatWild { attrs, .. }) => {
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
}
pats.push(None);
}
_ => {
return Err(Error::new(
pat.span(),
"`trie_match` only supports string literal patterns",
));
return Err(Error::new(pat.span(), ERROR_UNEXPECTED_PATTERN));
}
}
Ok(pats)
}

struct MatchInfo {
bodies: Vec<Expr>,
pattern_map: HashMap<String, usize>,
pattern_map: HashMap<Vec<u8>, usize>,
wildcard_idx: usize,
}

Expand All @@ -130,32 +192,29 @@ fn parse_match_arms(arms: Vec<Arm>) -> Result<MatchInfo, Error> {
) in arms.into_iter().enumerate()
{
if let Some(attr) = attrs.first() {
return Err(Error::new(attr.span(), "attribute not supported here"));
return Err(Error::new(attr.span(), ERROR_ATTRIBUTE_NOT_SUPPORTED));
}
if let Some((if_token, _)) = guard {
return Err(Error::new(if_token.span(), "match guard not supported"));
return Err(Error::new(if_token.span(), ERROR_GUARD_NOT_SUPPORTED));
}
let pat_strs = retrieve_match_patterns(&pat)?;
for pat_str in pat_strs {
if let Some(pat_str) = pat_str {
if pattern_map.contains_key(&pat_str) {
return Err(Error::new(pat.span(), "unreachable pattern"));
let pat_bytes_set = retrieve_match_patterns(&pat)?;
for pat_bytes in pat_bytes_set {
if let Some(pat_bytes) = pat_bytes {
if pattern_map.contains_key(&pat_bytes) {
return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
}
pattern_map.insert(pat_str, i);
pattern_map.insert(pat_bytes, i);
} else {
if wildcard_idx.is_some() {
return Err(Error::new(pat.span(), "unreachable pattern"));
return Err(Error::new(pat.span(), ERROR_UNREACHABLE_PATTERN));
}
wildcard_idx.replace(i);
}
}
bodies.push(*body);
}
let Some(wildcard_idx) = wildcard_idx else {
return Err(Error::new(
Span::call_site(),
"non-exhaustive patterns: `_` not covered",
));
return Err(Error::new(Span::call_site(), ERROR_PATTERN_NOT_COVERED));
};
Ok(MatchInfo {
bodies,
Expand All @@ -168,13 +227,11 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
let ExprMatch {
attrs, expr, arms, ..
} = input;

let MatchInfo {
bodies,
pattern_map,
wildcard_idx,
} = parse_match_arms(arms)?;

let mut trie = Sparse::new();
for (k, v) in pattern_map {
if v == wildcard_idx {
Expand Down Expand Up @@ -203,12 +260,12 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
#( #enumvalue, )*
}
#( #attr )*
match (|query: &str| unsafe {
match (|query: &[u8]| unsafe {
let bases: &'static [i32] = &[ #( #base, )* ];
let out_checks: &'static [(__TrieMatchValue, u8)] = &[ #( #out_check, )* ];
let mut pos = 0;
let mut base = bases[0];
for &b in query.as_bytes() {
for &b in query {
pos = base.wrapping_add(i32::from(b)) as usize;
if let Some((_, check)) = out_checks.get(pos) {
if *check == b {
Expand All @@ -219,7 +276,7 @@ fn trie_match_inner(input: ExprMatch) -> Result<TokenStream, Error> {
return __TrieMatchValue::#wildcard_ident;
}
out_checks.get_unchecked(pos).0
})( #expr ) {
})( ::std::convert::AsRef::<[u8]>::as_ref( #expr ) ) {
#( #arm, )*
}
}
Expand Down
56 changes: 56 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,59 @@ fn test_invalid_root_check_of_zero() {
};
assert_eq!(f("\u{0}\u{1}"), 0);
}

/// A byte string literal pattern matches only the identical byte sequence.
#[test]
fn test_bytes_literal() {
    let classify = |bytes: &[u8]| {
        trie_match! {
            match bytes {
                b"abc" => 0,
                _ => 1,
            }
        }
    };
    let cases: [(&[u8], i32); 2] = [(b"abc", 0), (b"ab", 1)];
    for (input, expected) in cases {
        assert_eq!(classify(input), expected);
    }
}

/// A slice pattern made of byte literals matches only the identical byte sequence.
#[test]
fn test_slice_byte_literal() {
    let classify = |bytes: &[u8]| {
        trie_match! {
            match bytes {
                [b'a', b'b', b'c'] => 0,
                _ => 1,
            }
        }
    };
    let cases: [(&[u8], i32); 2] = [(b"abc", 0), (b"ab", 1)];
    for (input, expected) in cases {
        assert_eq!(classify(input), expected);
    }
}

/// A slice pattern made of integer literals matches only the identical byte sequence.
#[test]
fn test_slice_numbers() {
    let classify = |bytes: &[u8]| {
        trie_match! {
            match bytes {
                [0, 1, 2] => 0,
                _ => 1,
            }
        }
    };
    let cases: [(&[u8], i32); 2] = [(&[0, 1, 2], 0), (&[0, 1], 1)];
    for (input, expected) in cases {
        assert_eq!(classify(input), expected);
    }
}

/// A referenced slice pattern (`&[..]`) behaves like the bare slice pattern.
#[test]
fn test_slice_ref_numbers() {
    let classify = |bytes: &[u8]| {
        trie_match! {
            match bytes {
                &[0, 1, 2] => 0,
                _ => 1,
            }
        }
    };
    let cases: [(&[u8], i32); 2] = [(&[0, 1, 2], 0), (&[0, 1], 1)];
    for (input, expected) in cases {
        assert_eq!(classify(input), expected);
    }
}