Skip to content

Commit

Permalink
Return possible byte candidates from next_byte()
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Jan 25, 2025
1 parent 94a980a commit 96bdecc
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 66 deletions.
136 changes: 112 additions & 24 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::{
fmt::Debug,
hash::Hash,
ops::{BitAnd, BitOr, RangeInclusive},
ops::{BitOr, RangeInclusive},
};

use crate::{hashcons::VecHashCons, pp::PrettyPrinter, simplify::OwnedConcatElement};
use crate::{hashcons::VecHashCons, pp::PrettyPrinter, simplify::OwnedConcatElement, AlphabetInfo};
use bytemuck_derive::{Pod, Zeroable};
use hashbrown::HashMap;

Expand Down Expand Up @@ -621,48 +622,135 @@ impl ExprSet {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum NextByte {
/// Transition via any other byte, or EOI leads to a dead state.
ForcedByte(u8),
/// Transition via any byte leads to a dead state but EOI is possible.
ForcedEOI,
/// Transition via some bytes *may be* possible.
SomeBytes,
/// The bytes are possible examples.
SomeBytes0,
SomeBytes1(u8),
SomeBytes2([u8; 2]),
/// The current state is dead.
/// Should be only true for NO_MATCH.
Dead,
}

impl BitAnd for NextByte {
type Output = Self;
fn bitand(self, other: Self) -> Self {
if self == other {
self
/// Debug formatting that shows byte payloads as `char` literals
/// (for example `ForcedByte('a')`) instead of raw numeric values.
impl Debug for NextByte {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            // Payload-free variants print just their name.
            NextByte::ForcedEOI => f.write_str("ForcedEOI"),
            NextByte::SomeBytes0 => f.write_str("SomeBytes0"),
            NextByte::Dead => f.write_str("Dead"),
            // Byte-carrying variants render each byte through `char`'s Debug
            // impl, so control characters come out escaped (e.g. '\n').
            NextByte::ForcedByte(byte) => write!(f, "ForcedByte({:?})", char::from(byte)),
            NextByte::SomeBytes1(byte) => write!(f, "SomeBytes1({:?})", char::from(byte)),
            NextByte::SomeBytes2([first, second]) => write!(
                f,
                "SomeBytes2({:?}, {:?})",
                char::from(first),
                char::from(second)
            ),
        }
    }
}

impl NextByte {
/// Returns the concrete byte(s) carried by this value, borrowed as a slice.
///
/// `ForcedByte` and `SomeBytes1` yield a one-element slice, `SomeBytes2`
/// yields both of its bytes, and every other variant yields an empty slice.
pub fn some_bytes(&self) -> &[u8] {
    match self {
        NextByte::SomeBytes2(pair) => &pair[..],
        NextByte::ForcedByte(single) | NextByte::SomeBytes1(single) => {
            std::slice::from_ref(single)
        }
        _ => &[],
    }
}

/// Returns `true` when this value is one of the non-forced "some bytes"
/// variants (`SomeBytes0`, `SomeBytes1` or `SomeBytes2`), and `false` for
/// every other variant.
pub fn is_some_bytes(&self) -> bool {
    matches!(
        self,
        NextByte::SomeBytes0 | NextByte::SomeBytes1(_) | NextByte::SomeBytes2(_)
    )
}

/// Builds a "some bytes" value from up to two candidate bytes.
///
/// An empty slice gives `SomeBytes0`, a single byte gives `SomeBytes1`, and
/// anything longer gives `SomeBytes2` holding the first two bytes; any
/// further bytes in `s` are ignored.
pub fn some_bytes_from_slice(s: &[u8]) -> Self {
    match s {
        [] => NextByte::SomeBytes0,
        [only] => NextByte::SomeBytes1(*only),
        [first, second, ..] => NextByte::SomeBytes2([*first, *second]),
    }
}

/// Weakens a forced result into a merely "possible" one.
///
/// `ForcedByte(b)` becomes `SomeBytes1(b)` and `ForcedEOI` becomes
/// `SomeBytes0`; every other variant is returned unchanged.
pub fn make_fuzzy(&self) -> Self {
    // `NextByte` is `Copy`, so match by value instead of cloning.
    match *self {
        NextByte::ForcedByte(byte) => NextByte::SomeBytes1(byte),
        NextByte::ForcedEOI => NextByte::SomeBytes0,
        other => other,
    }
}

fn sorted_some_bytes(a: u8, b: u8) -> Self {
assert!(a != b);
if a < b {
NextByte::SomeBytes2([a, b])
} else {
if self == NextByte::SomeBytes {
other
} else if other == NextByte::SomeBytes {
self
} else {
NextByte::Dead
NextByte::SomeBytes2([b, a])
}
}

/// Translates the byte(s) carried by this value through `alpha.inv_map`.
///
/// `inv_map` returns a pair of bytes for the given index; the two are equal
/// when only a single byte is recorded for it.
///
/// * `ForcedByte(b)` stays forced only if the pair is a single byte;
///   otherwise it becomes a `SomeBytes2` holding both bytes of the pair.
/// * `SomeBytes1(b)` likewise stays `SomeBytes1` for a single byte and
///   becomes `SomeBytes2` otherwise.
/// * `SomeBytes2([a, b])` maps each entry to the first byte of its pair.
/// * All other variants are returned unchanged.
///
/// # Panics
///
/// Panics (via the assertion in `sorted_some_bytes`) if the two bytes of a
/// `SomeBytes2` map to the same first byte.
pub fn map_alpha(&self, alpha: &AlphabetInfo) -> Self {
    // `NextByte` is `Copy`, so match by value instead of cloning.
    match *self {
        NextByte::ForcedByte(class) => {
            let (first, second) = alpha.inv_map(usize::from(class));
            if first == second {
                NextByte::ForcedByte(first)
            } else {
                Self::sorted_some_bytes(first, second)
            }
        }
        NextByte::SomeBytes1(class) => {
            let (first, second) = alpha.inv_map(usize::from(class));
            if first == second {
                NextByte::SomeBytes1(first)
            } else {
                Self::sorted_some_bytes(first, second)
            }
        }
        NextByte::SomeBytes2([class_a, class_b]) => {
            // Only two slots are available, so keep one byte per entry.
            let byte_a = alpha.inv_map(usize::from(class_a)).0;
            let byte_b = alpha.inv_map(usize::from(class_b)).0;
            Self::sorted_some_bytes(byte_a, byte_b)
        }
        other => other,
    }
}
}

impl BitOr for NextByte {
type Output = Self;
fn bitor(self, other: Self) -> Self {
if self == other {
self
} else {
if self == NextByte::Dead {
other
} else if other == NextByte::Dead {
self
} else {
NextByte::SomeBytes
match (self, other) {
(NextByte::Dead, _) => other,
(_, NextByte::Dead) => self,
(NextByte::ForcedByte(a), NextByte::ForcedByte(b)) => {
if a == b {
self
} else {
NextByte::SomeBytes2([a, b])
}
}
(NextByte::ForcedEOI, NextByte::ForcedEOI) => self,
_ => {
let a = self.some_bytes();
let b = other.some_bytes();
if a.is_empty() || b.len() > 1 {
NextByte::some_bytes_from_slice(b)
} else if b.is_empty() || a.len() > 1 {
NextByte::some_bytes_from_slice(a)
} else {
let a = a[0];
let b = b[0];
if a == b {
NextByte::SomeBytes1(a)
} else {
NextByte::SomeBytes2([a, b])
}
}
}
}
}
Expand Down
62 changes: 44 additions & 18 deletions src/nextbyte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,61 @@ pub struct NextByteCache {
}

pub(crate) fn next_byte_simple(exprs: &ExprSet, mut r: ExprRef) -> NextByte {
loop {
let mut fuzzy = false;
let res = 'dfs: loop {
match exprs.get(r) {
Expr::EmptyString => return NextByte::ForcedEOI,
Expr::NoMatch => return NextByte::Dead,
Expr::ByteSet(_) => return NextByte::SomeBytes,
Expr::Byte(b) => return NextByte::ForcedByte(b),
Expr::ByteConcat(_, bytes, _) => return NextByte::ForcedByte(bytes[0]),
Expr::And(_, _) => return NextByte::SomeBytes,
Expr::Not(_, _) => return NextByte::SomeBytes,
Expr::RemainderIs { .. } => return NextByte::SomeBytes,
Expr::EmptyString => break NextByte::ForcedEOI,
Expr::NoMatch => break NextByte::Dead,
Expr::ByteSet(lst) => {
let mut b0 = None;
for (idx, &w) in lst.iter().enumerate() {
if w > 0 {
let b = (idx as u32 * 32 + w.trailing_zeros()) as u8;
if b0.is_some() {
break 'dfs NextByte::SomeBytes2([b0.unwrap(), b]);
} else {
b0 = Some(b);
}
let w = w & !(1 << (b as u32 % 32));
if w > 0 {
let b2 = (idx as u32 * 32 + w.trailing_zeros()) as u8;
break 'dfs NextByte::SomeBytes2([b, b2]);
}
}
}
unreachable!("ByteSet should have at least two bytes set");
}
Expr::Byte(b) => break NextByte::ForcedByte(b),
Expr::ByteConcat(_, bytes, _) => break NextByte::ForcedByte(bytes[0]),
Expr::Or(_, args) | Expr::And(_, args) => {
fuzzy = true;
r = args[0];
}
Expr::Not(_, _) => break NextByte::SomeBytes0,
Expr::RemainderIs { .. } => {
break NextByte::SomeBytes2([exprs.digits[0], exprs.digits[1]]);
}
Expr::Lookahead(_, e, _) => {
r = e;
}
Expr::Repeat(_, arg, min, _) => {
if min == 0 {
return NextByte::SomeBytes;
} else {
r = arg;
fuzzy = true;
}
r = arg;
}
Expr::Concat(_, args) => {
if exprs.is_nullable(args[0]) {
return NextByte::SomeBytes;
} else {
r = args[0];
fuzzy = true;
}
r = args[0];
}
Expr::Or(_, _) => return NextByte::SomeBytes,
}
};
if fuzzy {
res.make_fuzzy()
} else {
res
}
}

Expand All @@ -59,10 +85,10 @@ impl NextByteCache {
Expr::Or(_, args) => {
let mut found = next_byte_simple(exprs, args[0]);
for child in args.iter().skip(1) {
found = found | next_byte_simple(exprs, *child);
if found == NextByte::SomeBytes {
if found.is_some_bytes() {
break;
}
found = found | next_byte_simple(exprs, *child);
}
found
}
Expand Down
32 changes: 12 additions & 20 deletions src/regex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl Debug for StateID {
#[derive(Clone)]
pub struct AlphabetInfo {
mapping: [u8; 256],
inv_mapping: [Option<u8>; 256],
inv_mapping: [(u8, u8); 256],
size: usize,
}

Expand Down Expand Up @@ -239,16 +239,7 @@ impl Regex {

let e = Self::resolve(&self.rx_sets, state);
let next_byte = self.next_byte.next_byte(&self.exprs, e);
let next_byte = match next_byte {
NextByte::ForcedByte(b) => {
if let Some(b) = self.alpha.inv_map(b as usize) {
NextByte::ForcedByte(b)
} else {
NextByte::SomeBytes
}
}
_ => next_byte,
};
let next_byte = next_byte.map_alpha(&self.alpha);
desc.next_byte = Some(next_byte);
next_byte
}
Expand Down Expand Up @@ -340,16 +331,17 @@ impl AlphabetInfo {
// disable expensive optimizations after initial construction
exprset.disable_optimizations();

let mut inv_alphabet_mapping = [None; 256];
let mut num_mappings = [0; 256];
let mut inv_alphabet_mapping = [(0u8, 0u8); 256];
let mut num_mappings = [0u16; 256];
for (i, &b) in mapping.iter().enumerate() {
inv_alphabet_mapping[b as usize] = Some(i as u8);
num_mappings[b as usize] += 1;
}
for i in 0..alphabet_size {
if num_mappings[i] != 1 {
inv_alphabet_mapping[i] = None;
let bi = b as usize;
let i_byte = i as u8;
if num_mappings[bi] == 0 {
inv_alphabet_mapping[bi] = (i_byte, i_byte);
} else if num_mappings[bi] == 1 {
inv_alphabet_mapping[bi].1 = i_byte;
}
num_mappings[b as usize] += 1;
}

debug!(
Expand Down Expand Up @@ -383,7 +375,7 @@ impl AlphabetInfo {
}
}

pub fn inv_map(&self, v: usize) -> Option<u8> {
pub fn inv_map(&self, v: usize) -> (u8, u8) {
self.inv_mapping[v]
}

Expand Down
9 changes: 5 additions & 4 deletions tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ fn unicode_case() {
fn validate_next_byte(rx: &mut Regex, data: Vec<(NextByte, u8)>) {
let mut s = rx.initial_state();
for (exp, b) in data {
println!("next_byte {:?} {:?}", exp, b as char);
let nb = rx.next_byte(s);
if nb != exp {
panic!("expected {:?}, got {:?}", exp, nb);
Expand All @@ -269,8 +270,8 @@ fn next_byte() {
&mut rx,
vec![
(NextByte::ForcedByte(b'a'), b'a'),
(NextByte::SomeBytes, b'b'),
(NextByte::SomeBytes, b'd'),
(NextByte::SomeBytes2([b'b', b'c']), b'b'),
(NextByte::SomeBytes2([b'b', b'c']), b'd'),
(NextByte::ForcedByte(b'x'), b'x'),
(NextByte::ForcedEOI, b'x'),
],
Expand All @@ -281,7 +282,7 @@ fn next_byte() {
&mut rx,
vec![
(NextByte::ForcedByte(b'a'), b'a'),
(NextByte::SomeBytes, b'B'),
(NextByte::SomeBytes2([b'B', b'b']), b'B'),
(NextByte::ForcedByte(b'D'), b'D'),
],
);
Expand All @@ -290,7 +291,7 @@ fn next_byte() {
validate_next_byte(
&mut rx,
vec![
(NextByte::SomeBytes, b'f'),
(NextByte::SomeBytes2([b'b', b'f']), b'f'),
(NextByte::ForcedByte(b'o'), b'o'),
(NextByte::ForcedByte(b'o'), b'o'),
(NextByte::ForcedEOI, b'X'),
Expand Down

0 comments on commit 96bdecc

Please sign in to comment.