diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs index ade441b0e7d5c..06849b3125683 100644 --- a/compiler/rustc_parse/src/parser/item.rs +++ b/compiler/rustc_parse/src/parser/item.rs @@ -423,7 +423,7 @@ impl<'a> Parser<'a> { // Maybe the user misspelled `macro_rules` (issue #91227) if self.token.is_ident() && path.segments.len() == 1 - && lev_distance("macro_rules", &path.segments[0].ident.to_string()) <= 3 + && lev_distance("macro_rules", &path.segments[0].ident.to_string(), 3).is_some() { err.span_suggestion( path.span, diff --git a/compiler/rustc_span/src/lev_distance.rs b/compiler/rustc_span/src/lev_distance.rs index aed699e4839e9..93cf965f1056b 100644 --- a/compiler/rustc_span/src/lev_distance.rs +++ b/compiler/rustc_span/src/lev_distance.rs @@ -11,16 +11,21 @@ use std::cmp; mod tests; /// Finds the Levenshtein distance between two strings. -pub fn lev_distance(a: &str, b: &str) -> usize { - // cases which don't require further computation - if a.is_empty() { - return b.chars().count(); - } else if b.is_empty() { - return a.chars().count(); +/// +/// Returns None if the distance exceeds the limit. +pub fn lev_distance(a: &str, b: &str, limit: usize) -> Option { + let n = a.chars().count(); + let m = b.chars().count(); + let min_dist = if n < m { m - n } else { n - m }; + + if min_dist > limit { + return None; + } + if n == 0 || m == 0 { + return (min_dist <= limit).then_some(min_dist); } - let mut dcol: Vec<_> = (0..=b.len()).collect(); - let mut t_last = 0; + let mut dcol: Vec<_> = (0..=m).collect(); for (i, sc) in a.chars().enumerate() { let mut current = i; @@ -35,10 +40,10 @@ pub fn lev_distance(a: &str, b: &str) -> usize { dcol[j + 1] = cmp::min(dcol[j + 1], dcol[j]) + 1; } current = next; - t_last = j; } } - dcol[t_last + 1] + + (dcol[m] <= limit).then_some(dcol[m]) } /// Finds the best match for a given word in the given iterator. @@ -51,39 +56,38 @@ pub fn lev_distance(a: &str, b: &str) -> usize { /// on an edge case with a lower(upper)case letters mismatch. #[cold] pub fn find_best_match_for_name( - name_vec: &[Symbol], + candidates: &[Symbol], lookup: Symbol, dist: Option, ) -> Option { let lookup = lookup.as_str(); - let max_dist = dist.unwrap_or_else(|| cmp::max(lookup.len(), 3) / 3); + let lookup_uppercase = lookup.to_uppercase(); // Priority of matches: // 1. Exact case insensitive match // 2. Levenshtein distance match // 3. Sorted word match - if let Some(case_insensitive_match) = - name_vec.iter().find(|candidate| candidate.as_str().to_uppercase() == lookup.to_uppercase()) - { - return Some(*case_insensitive_match); + if let Some(c) = candidates.iter().find(|c| c.as_str().to_uppercase() == lookup_uppercase) { + return Some(*c); } - let levenshtein_match = name_vec - .iter() - .filter_map(|&name| { - let dist = lev_distance(lookup, name.as_str()); - if dist <= max_dist { Some((name, dist)) } else { None } - }) - // Here we are collecting the next structure: - // (levenshtein_match, levenshtein_distance) - .fold(None, |result, (candidate, dist)| match result { - None => Some((candidate, dist)), - Some((c, d)) => Some(if dist < d { (candidate, dist) } else { (c, d) }), - }); - if levenshtein_match.is_some() { - levenshtein_match.map(|(candidate, _)| candidate) - } else { - find_match_by_sorted_words(name_vec, lookup) + + let mut dist = dist.unwrap_or_else(|| cmp::max(lookup.len(), 3) / 3); + let mut best = None; + for c in candidates { + match lev_distance(lookup, c.as_str(), dist) { + Some(0) => return Some(*c), + Some(d) => { + dist = d - 1; + best = Some(*c); + } + None => {} + } } + if best.is_some() { + return best; + } + + find_match_by_sorted_words(candidates, lookup) } fn find_match_by_sorted_words(iter_names: &[Symbol], lookup: &str) -> Option { diff --git a/compiler/rustc_span/src/lev_distance/tests.rs b/compiler/rustc_span/src/lev_distance/tests.rs index b32f8d32c1391..4e34219248d41 100644 --- a/compiler/rustc_span/src/lev_distance/tests.rs +++ b/compiler/rustc_span/src/lev_distance/tests.rs @@ -5,18 +5,26 @@ fn test_lev_distance() { use std::char::{from_u32, MAX}; // Test bytelength agnosticity for c in (0..MAX as u32).filter_map(from_u32).map(|i| i.to_string()) { - assert_eq!(lev_distance(&c[..], &c[..]), 0); + assert_eq!(lev_distance(&c[..], &c[..], usize::MAX), Some(0)); } let a = "\nMäry häd ä little lämb\n\nLittle lämb\n"; let b = "\nMary häd ä little lämb\n\nLittle lämb\n"; let c = "Mary häd ä little lämb\n\nLittle lämb\n"; - assert_eq!(lev_distance(a, b), 1); - assert_eq!(lev_distance(b, a), 1); - assert_eq!(lev_distance(a, c), 2); - assert_eq!(lev_distance(c, a), 2); - assert_eq!(lev_distance(b, c), 1); - assert_eq!(lev_distance(c, b), 1); + assert_eq!(lev_distance(a, b, usize::MAX), Some(1)); + assert_eq!(lev_distance(b, a, usize::MAX), Some(1)); + assert_eq!(lev_distance(a, c, usize::MAX), Some(2)); + assert_eq!(lev_distance(c, a, usize::MAX), Some(2)); + assert_eq!(lev_distance(b, c, usize::MAX), Some(1)); + assert_eq!(lev_distance(c, b, usize::MAX), Some(1)); +} + +#[test] +fn test_lev_distance_limit() { + assert_eq!(lev_distance("abc", "abcd", 1), Some(1)); + assert_eq!(lev_distance("abc", "abcd", 0), None); + assert_eq!(lev_distance("abc", "xyz", 3), Some(3)); + assert_eq!(lev_distance("abc", "xyz", 2), None); } #[test] diff --git a/compiler/rustc_span/src/lib.rs b/compiler/rustc_span/src/lib.rs index 92360164a019b..29c76027c150e 100644 --- a/compiler/rustc_span/src/lib.rs +++ b/compiler/rustc_span/src/lib.rs @@ -15,6 +15,7 @@ #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")] #![feature(array_windows)] +#![feature(bool_to_option)] #![feature(crate_visibility_modifier)] #![feature(if_let_guard)] #![feature(negative_impls)] diff --git a/compiler/rustc_typeck/src/check/method/probe.rs b/compiler/rustc_typeck/src/check/method/probe.rs index 9efaa37633e3e..3815fd1992bf3 100644 --- a/compiler/rustc_typeck/src/check/method/probe.rs +++ b/compiler/rustc_typeck/src/check/method/probe.rs @@ -1904,8 +1904,13 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> { .associated_items(def_id) .in_definition_order() .filter(|x| { - let dist = lev_distance(name.as_str(), x.name.as_str()); - x.kind.namespace() == Namespace::ValueNS && dist > 0 && dist <= max_dist + if x.kind.namespace() != Namespace::ValueNS { + return false; + } + match lev_distance(name.as_str(), x.name.as_str(), max_dist) { + Some(d) => d > 0, + None => false, + } }) .copied() .collect()