Skip to content

Commit

Permalink
Cache substituted patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Sep 26, 2024
1 parent 4a6ed83 commit 628ebfe
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 6 deletions.
17 changes: 17 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,9 @@ impl PythonTransformer {
/// The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree,
/// depending on `level_is_tree_depth`.
///
/// For efficiency, the first `rhs_cache_size` substituted patterns are cached.
/// If set to `None`, an internally determined cache size is used.
///
/// Examples
/// --------
///
Expand All @@ -1172,6 +1175,7 @@ impl PythonTransformer {
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
rhs_cache_size: Option<usize>,
) -> PyResult<PythonTransformer> {
let mut settings = MatchSettings::default();

Expand Down Expand Up @@ -1203,6 +1207,9 @@ impl PythonTransformer {
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}
if let Some(rhs_cache_size) = rhs_cache_size {
settings.rhs_cache_size = rhs_cache_size;
}

return append_transformer!(
self,
Expand Down Expand Up @@ -3787,6 +3794,8 @@ impl PythonExpression {
/// If set to `True`, the level is increased when going one level deeper in the expression tree.
/// allow_new_wildcards_on_rhs: bool, optional
/// If set to `True`, wildcards that do not appear in the pattern are allowed on the right-hand side.
/// rhs_cache_size: int, optional
/// Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
/// repeat: bool, optional
/// If set to `True`, the entire operation will be repeated until there are no more matches.
pub fn replace_all(
Expand All @@ -3798,6 +3807,7 @@ impl PythonExpression {
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
rhs_cache_size: Option<usize>,
repeat: Option<bool>,
) -> PyResult<PythonExpression> {
let pattern = &pattern.to_pattern()?.expr;
Expand Down Expand Up @@ -3833,6 +3843,9 @@ impl PythonExpression {
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}
if let Some(rhs_cache_size) = rhs_cache_size {
settings.rhs_cache_size = rhs_cache_size;
}

let mut expr_ref = self.expr.as_view();

Expand Down Expand Up @@ -4410,6 +4423,7 @@ impl PythonReplacement {
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
rhs_cache_size: Option<usize>,
) -> PyResult<Self> {
let pattern = pattern.to_pattern()?.expr;
let rhs = rhs.to_pattern_or_map()?;
Expand Down Expand Up @@ -4444,6 +4458,9 @@ impl PythonReplacement {
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}
if let Some(rhs_cache_size) = rhs_cache_size {
settings.rhs_cache_size = rhs_cache_size;
}

let cond = cond
.as_ref()
Expand Down
2 changes: 1 addition & 1 deletion src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub enum AtomType {
Fun,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SliceType {
Add,
Mul,
Expand Down
58 changes: 53 additions & 5 deletions src/id.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{ops::DerefMut, str::FromStr};

use ahash::HashSet;
use ahash::{HashMap, HashSet};
use dyn_clone::DynClone;

use crate::{
Expand Down Expand Up @@ -306,7 +306,8 @@ impl<'a> AtomView<'a> {
out: &mut Atom,
) -> bool {
Workspace::get_local().with(|ws| {
let matched = self.replace_all_no_norm(replacements, ws, 0, 0, out);
let mut rhs_cache = HashMap::default();
let matched = self.replace_all_no_norm(replacements, ws, 0, 0, &mut rhs_cache, out);

if matched {
let mut norm = ws.new_atom();
Expand All @@ -325,6 +326,7 @@ impl<'a> AtomView<'a> {
workspace: &Workspace,
tree_level: usize,
fn_level: usize,
rhs_cache: &mut HashMap<Vec<(Symbol, Match<'a>)>, Atom>,
out: &mut Atom,
) -> bool {
let mut beyond_max_level = true;
Expand Down Expand Up @@ -355,6 +357,11 @@ impl<'a> AtomView<'a> {

let mut it = AtomMatchIterator::new(r.pat, *self);
if let Some((_, used_flags)) = it.next(&mut match_stack) {
if let Some(rhs) = rhs_cache.get(&match_stack.stack) {
out.set_from_view(&rhs.as_view());
return true;
}

let mut rhs_subs = workspace.new_atom();

match r.rhs {
Expand All @@ -371,6 +378,13 @@ impl<'a> AtomView<'a> {
if used_flags.iter().all(|x| *x) {
// all used, return rhs
out.set_from_view(&rhs_subs.as_view());

if rhs_cache.len() < settings.rhs_cache_size
&& !matches!(r.rhs, PatternOrMap::Pattern(Pattern::Literal(_)))
{
rhs_cache.insert(match_stack.stack.clone(), rhs_subs.into_inner());
}

return true;
}

Expand Down Expand Up @@ -402,6 +416,12 @@ impl<'a> AtomView<'a> {
}
}

if rhs_cache.len() < settings.rhs_cache_size
&& !matches!(r.rhs, PatternOrMap::Pattern(Pattern::Literal(_)))
{
rhs_cache.insert(match_stack.stack.clone(), rhs_subs.into_inner());
}

return true;
}
}
Expand All @@ -426,6 +446,7 @@ impl<'a> AtomView<'a> {
workspace,
tree_level + 1,
fn_level + 1,
rhs_cache,
&mut child_buf,
);

Expand All @@ -444,6 +465,7 @@ impl<'a> AtomView<'a> {
workspace,
tree_level + 1,
fn_level,
rhs_cache,
&mut base_out,
);

Expand All @@ -453,6 +475,7 @@ impl<'a> AtomView<'a> {
workspace,
tree_level + 1,
fn_level,
rhs_cache,
&mut exp_out,
);

Expand All @@ -471,6 +494,7 @@ impl<'a> AtomView<'a> {
workspace,
tree_level + 1,
fn_level,
rhs_cache,
&mut child_buf,
);

Expand All @@ -491,6 +515,7 @@ impl<'a> AtomView<'a> {
workspace,
tree_level + 1,
fn_level,
rhs_cache,
&mut child_buf,
);

Expand Down Expand Up @@ -1153,7 +1178,15 @@ impl Pattern {
rep = rep.with_settings(s);
}

let matched = target.replace_all_no_norm(std::slice::from_ref(&rep), workspace, 0, 0, out);
let mut rhs_cache = HashMap::default();
let matched = target.replace_all_no_norm(
std::slice::from_ref(&rep),
workspace,
0,
0,
&mut rhs_cache,
out,
);

if matched {
let mut norm = workspace.new_atom();
Expand Down Expand Up @@ -1545,7 +1578,7 @@ impl std::fmt::Debug for PatternRestriction {
}

/// A part of an expression that was matched to a wildcard.
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Match<'a> {
/// A matched single atom.
Single(AtomView<'a>),
Expand Down Expand Up @@ -1664,7 +1697,7 @@ impl<'a> Match<'a> {
}

/// Settings related to pattern matching.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Clone)]
pub struct MatchSettings {
/// Specifies wildcards that try to match as little as possible.
pub non_greedy_wildcards: Vec<Symbol>,
Expand All @@ -1676,6 +1709,21 @@ pub struct MatchSettings {
pub level_is_tree_depth: bool,
/// Allow wildcards on the right-hand side that do not appear in the pattern.
pub allow_new_wildcards_on_rhs: bool,
/// The maximum size of the cache for the right-hand side of a replacement.
/// This can be used to prevent expensive recomputations.
pub rhs_cache_size: usize,
}

impl Default for MatchSettings {
fn default() -> Self {
Self {
non_greedy_wildcards: Vec::new(),
level_range: (0, None),
level_is_tree_depth: false,
allow_new_wildcards_on_rhs: false,
rhs_cache_size: 100,
}
}
}

/// An insertion-ordered map of wildcard identifiers to subexpressions.
Expand Down
5 changes: 5 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ class Expression:
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False,
rhs_cache_size: Optional[int] = None,
repeat: Optional[bool] = False,
) -> Expression:
"""
Expand Down Expand Up @@ -1096,6 +1097,8 @@ class Expression:
If set to `True`, the level is increased when going one level deeper in the expression tree.
allow_new_wildcards_on_rhs: bool, optional
If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
rhs_cache_size: int, optional
Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
repeat: bool, optional
If set to `True`, the entire operation will be repeated until there are no more matches.
"""
Expand Down Expand Up @@ -1767,6 +1770,8 @@ class Transformer:
If set to `True`, the level is increased when going one level deeper in the expression tree.
allow_new_wildcards_on_rhs:
If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
rhs_cache_size: int, optional
Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
repeat:
If set to `True`, the entire operation will be repeated until there are no more matches.
"""
Expand Down

0 comments on commit 628ebfe

Please sign in to comment.