diff --git a/src/Lean/Meta/LazyDiscrTree.lean b/src/Lean/Meta/LazyDiscrTree.lean index 103b760510f6..7f3200991137 100644 --- a/src/Lean/Meta/LazyDiscrTree.lean +++ b/src/Lean/Meta/LazyDiscrTree.lean @@ -445,6 +445,35 @@ private def newTrie [Monad m] [MonadState (Array (Trie α)) m] (e : LazyEntry α private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit := modify (·.modify i (·.pushPending e)) +private def evalLazyEntry (config : WhnfCoreConfig) + (p : Array α × TrieIndex × HashMap Key TrieIndex) + (entry : LazyEntry α) + : MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do + let (values, starIdx, children) := p + let (todo, lctx, v) := entry + if todo.isEmpty then + let values := values.push v + pure (values, starIdx, children) + else + let e := todo.back + let todo := todo.pop + let (k, todo) ← withLCtx lctx.1 lctx.2 $ pushArgs false todo e config + if k == .star then + if starIdx = 0 then + let starIdx ← newTrie (todo, lctx, v) + pure (values, starIdx, children) + else + addLazyEntryToTrie starIdx (todo, lctx, v) + pure (values, starIdx, children) + else + match children.find? k with + | none => + let children := children.insert k (← newTrie (todo, lctx, v)) + pure (values, starIdx, children) + | some idx => + addLazyEntryToTrie idx (todo, lctx, v) + pure (values, starIdx, children) + /-- This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `children` accordingly. @@ -453,34 +482,10 @@ private partial def evalLazyEntries (config : WhnfCoreConfig) (values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex) (entries : Array (LazyEntry α)) : MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do - let rec iter values starIdx children (i : Nat) : MatchM α _ := do - if p : i < entries.size then - let (todo, lctx, v) := entries[i] - if todo.isEmpty then - let values := values.push v - iter values starIdx children (i+1) - else - let e := todo.back - let todo := todo.pop - let (k, todo) ← withLCtx lctx.1 lctx.2 $ pushArgs false todo e config - if k == .star then - if starIdx = 0 then - let starIdx ← newTrie (todo, lctx, v) - iter values starIdx children (i+1) - else - addLazyEntryToTrie starIdx (todo, lctx, v) - iter values starIdx children (i+1) - else - match children.find? k with - | none => - let children := children.insert k (← newTrie (todo, lctx, v)) - iter values starIdx children (i+1) - | some idx => - addLazyEntryToTrie idx (todo, lctx, v) - iter values starIdx children (i+1) - else - pure (values, starIdx, children) - iter values starIdx children 0 + let mut values := values + let mut starIdx := starIdx + let mut children := children + entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config) private def evalNode (c : TrieIndex) : MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do @@ -590,7 +595,7 @@ private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieInde and there is an edge for `k` and `k != Key.star`. -/ let visitStar (result : MatchResult α) : MatchM α (MatchResult α) := if star != 0 then - getMatchLoop todo score star result + getMatchLoop todo (score + 1) star result else return result let visitNonStar (k : Key) (args : Array Expr) (result : MatchResult α) := @@ -617,13 +622,13 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match pure <| {} | some idx => do let (vs, _) ← evalNode idx - pure <| ({} : MatchResult α).push 0 vs + pure <| ({} : MatchResult α).push (score := 1) vs private def getMatchRoot (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr) (result : MatchResult α) : MatchM α (MatchResult α) := match r.find? k with | none => pure result - | some c => getMatchLoop args 1 c result + | some c => getMatchLoop args (score := 1) c result /-- Find values that match `e` in `root`. @@ -631,12 +636,12 @@ private def getMatchRoot (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Arra private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) : MatchM α (MatchResult α) := do let result ← getStarResult root - let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (←read) + let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read) match k with | .star => return result /- See note about "dep-arrow vs arrow" at `getMatchLoop` -/ | .arrow => - getMatchRoot root k args (←getMatchRoot root .other #[] result) + getMatchRoot root k args (← getMatchRoot root .other #[] result) | _ => getMatchRoot root k args result @@ -756,7 +761,28 @@ structure Cache where def Cache.empty (ngen : NameGenerator) : Cache := { ngen := ngen, core := {}, meta := {} } +def matchPrefix (s : String) (pre : String) := + s.startsWith pre && (s |>.drop pre.length |>.all Char.isDigit) + +def isInternalDetail : Name → Bool + | .str p s => + s.startsWith "_" + || matchPrefix s "eq_" + || matchPrefix s "match_" + || matchPrefix s "proof_" + || p.isInternalOrNum + | .num _ _ => true + | p => p.isInternalOrNum + +def blacklistInsertion (env : Environment) (declName : Name) : Bool := + !allowCompletion env declName + || declName == ``sorryAx + || isInternalDetail declName + || (declName matches .str _ "inj") + || (declName matches .str _ "noConfusionType") + private def addConstImportData + (cctx : Core.Context) (env : Environment) (modName : Name) (d : ImportData) @@ -765,16 +791,12 @@ private def addConstImportData (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) (name : Name) (constInfo : ConstantInfo) : BaseIO (PreDiscrTree α) := do if constInfo.isUnsafe then return tree - if !allowCompletion env name then return tree + if blacklistInsertion env name then return tree let { ngen, core := core_cache, meta := meta_cache } ← cacheRef.get let mstate : Meta.State := { cache := meta_cache } cacheRef.set (Cache.empty ngen) let ctx : Meta.Context := { config := { transparency := .reducible } } let cm := (act name constInfo).run ctx mstate - let cctx : Core.Context := { - fileName := default, - fileMap := default - } let cstate : Core.State := {env, cache := core_cache, ngen} match ←(cm.run cctx cstate).toBaseIO with | .ok ((a, ms), cs) => @@ -818,7 +840,9 @@ private def toFlat (d : ImportData) (tree : PreDiscrTree α) : let de ← d.errors.swap #[] pure ⟨tree, de⟩ -private partial def loadImportedModule (env : Environment) +private partial def loadImportedModule + (cctx : Core.Context) + (env : Environment) (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) (d : ImportData) (cacheRef : IO.Ref Cache) @@ -829,12 +853,12 @@ private partial def loadImportedModule (env : Environment) if h : i < mdata.constNames.size then let name := mdata.constNames[i] let constInfo := mdata.constants[i]! - let tree ← addConstImportData env mname d cacheRef tree act name constInfo - loadImportedModule env act d cacheRef tree mname mdata (i+1) + let tree ← addConstImportData cctx env mname d cacheRef tree act name constInfo + loadImportedModule cctx env act d cacheRef tree mname mdata (i+1) else pure tree -private def createImportedEnvironmentSeq (ngen : NameGenerator) (env : Environment) +private def createImportedEnvironmentSeq (cctx : Core.Context) (ngen : NameGenerator) (env : Environment) (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) (start stop : Nat) : BaseIO (InitResults α) := do let cacheRef ← IO.mkRef (Cache.empty ngen) @@ -843,7 +867,7 @@ private def createImportedEnvironmentSeq (ngen : NameGenerator) (env : Environme if start < stop then let mname := env.header.moduleNames[start]! let mdata := env.header.moduleData[start]! - let tree ← loadImportedModule env act d cacheRef tree mname mdata + let tree ← loadImportedModule cctx env act d cacheRef tree mname mdata go d cacheRef tree (start+1) stop else toFlat d tree @@ -860,6 +884,7 @@ def getChildNgen [Monad M] [MonadNameGenerator M] : M NameGenerator := do pure cngen def createLocalPreDiscrTree + (cctx : Core.Context) (ngen : NameGenerator) (env : Environment) (d : ImportData) @@ -868,18 +893,22 @@ def createLocalPreDiscrTree let modName := env.header.mainModule let cacheRef ← IO.mkRef (Cache.empty ngen) let act (t : PreDiscrTree α) (n : Name) (c : ConstantInfo) : BaseIO (PreDiscrTree α) := - addConstImportData env modName d cacheRef t act n c + addConstImportData cctx env modName d cacheRef t act n c let r ← (env.constants.map₂.foldlM (init := {}) act : BaseIO (PreDiscrTree α)) pure r def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do keys.foldlM (init := t) (·.dropKey ·) +def logImportFailure [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] (f : ImportFailure) : m Unit := + logError m!"Processing failure with {f.const} in {f.module}:\n {f.exception.toMessageData}" + /-- Create a discriminator tree for imported environment. -/ -def createImportedDiscrTree (ngen : NameGenerator) (env : Environment) +def createImportedDiscrTree [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] [MonadLiftT BaseIO m] + (cctx : Core.Context) (ngen : NameGenerator) (env : Environment) (act : Name → ConstantInfo → MetaM (Array (InitEntry α))) (constantsPerTask : Nat := 1000) : - EIO Exception (LazyDiscrTree α) := do + m (LazyDiscrTree α) := do let n := env.header.moduleData.size let rec /-- Allocate constants to tasks according to `constantsPerTask`. -/ @@ -889,29 +918,40 @@ def createImportedDiscrTree (ngen : NameGenerator) (env : Environment) let cnt := cnt + mdata.constants.size if cnt > constantsPerTask then let (childNGen, ngen) := ngen.mkChild - let t ← createImportedEnvironmentSeq childNGen env act start (idx+1) |>.asTask + let t ← liftM <| createImportedEnvironmentSeq cctx childNGen env act start (idx+1) |>.asTask go ngen (tasks.push t) (idx+1) 0 (idx+1) else go ngen tasks start cnt (idx+1) else if start < n then let (childNGen, _) := ngen.mkChild - tasks.push <$> (createImportedEnvironmentSeq childNGen env act start n).asTask + let t ← (createImportedEnvironmentSeq cctx childNGen env act start n).asTask + pure (tasks.push t) else pure tasks termination_by env.header.moduleData.size - idx let tasks ← go ngen #[] 0 0 0 let r := combineGet default tasks - if p : r.errors.size > 0 then - throw r.errors[0].exception + r.errors.forM logImportFailure pure <| r.tree.toLazy +/-- Creates the core context used for initializing a tree using the current context. -/ +private def createTreeCtx (ctx : Core.Context) : Core.Context := { + fileName := ctx.fileName, + fileMap := ctx.fileMap, + options := ctx.options, + maxRecDepth := ctx.maxRecDepth, + maxHeartbeats := 0, + ref := ctx.ref, + } + def findImportMatches (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α)))) (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α))) (droppedKeys : List (List LazyDiscrTree.Key) := []) (constantsPerTask : Nat := 1000) (ty : Expr) : MetaM (MatchResult α) := do + let cctx ← (read : CoreM Core.Context) let ngen ← getNGen let (cNGen, ngen) := ngen.mkChild setNGen ngen @@ -919,7 +959,7 @@ def findImportMatches let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv) let importTree ← (←ref.get).getDM $ do profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do - let t ← createImportedDiscrTree cNGen (←getEnv) addEntry + let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry (constantsPerTask := constantsPerTask) dropKeys t droppedKeys let (importCandidates, importTree) ← importTree.getMatch ty @@ -943,10 +983,9 @@ def createModuleDiscrTree let env ← getEnv let ngen ← getChildNgen let d ← ImportData.new - let t ← createLocalPreDiscrTree ngen env d entriesForConst - let errors ← d.errors.get - if p : errors.size > 0 then - throw errors[0].exception + let ctx ← read + let t ← createLocalPreDiscrTree ctx ngen env d entriesForConst + (← d.errors.get).forM logImportFailure pure <| t.toLazy /-- @@ -966,7 +1005,7 @@ Returns candidates from this module in this module that match the expression. this module's definitions. -/ def findModuleMatches (moduleRef : ModuleDiscrTreeRef α) (ty : Expr) : MetaM (MatchResult α) := do - profileitM Exception "lazy discriminator local search" (←getOptions) $ do + profileitM Exception "lazy discriminator local search" (← getOptions) $ do let discrTree ← moduleRef.ref.get let (localCandidates, localTree) ← discrTree.getMatch ty moduleRef.ref.set localTree diff --git a/src/Lean/Meta/Tactic/Rewrites.lean b/src/Lean/Meta/Tactic/Rewrites.lean index c4832351997d..7167412b943a 100644 --- a/src/Lean/Meta/Tactic/Rewrites.lean +++ b/src/Lean/Meta/Tactic/Rewrites.lean @@ -35,9 +35,12 @@ def forwardWeight := 2 /-- Weight to multiply the "specificity" of a rewrite lemma by when rewriting backwards. -/ def backwardWeight := 1 +inductive RwDirection : Type where + | forward : RwDirection + | backward : RwDirection private def addImport (name : Name) (constInfo : ConstantInfo) : - MetaM (Array (InitEntry (Name × Bool × Nat))) := do + MetaM (Array (InitEntry (Name × RwDirection))) := do if constInfo.isUnsafe then return #[] if !allowCompletion (←getEnv) name then return #[] -- We now remove some injectivity lemmas which are not useful to rewrite by. @@ -46,16 +49,22 @@ private def addImport (name : Name) (constInfo : ConstantInfo) : match name with | .str _ n => if n.endsWith "_inj" ∨ n.endsWith "_inj'" then return #[] | _ => pure () - withNewMCtxDepth do withReducible do - forallTelescopeReducing constInfo.type fun _ type => do - match type.getAppFnArgs with - | (``Eq, #[_, lhs, rhs]) - | (``Iff, #[lhs, rhs]) => do - let a := Array.mkEmpty 2 - let a := a.push (← InitEntry.fromExpr lhs (name, false, forwardWeight)) - let a := a.push (← InitEntry.fromExpr rhs (name, true, backwardWeight)) - pure a - | _ => return #[] + try + withNewMCtxDepth do withReducible do + forallTelescopeReducing constInfo.type fun _ type => do + match type.getAppFnArgs with + | (``Eq, #[_, lhs, rhs]) + | (``Iff, #[lhs, rhs]) => do + let a := Array.mkEmpty 2 + let a := a.push (← InitEntry.fromExpr lhs (name, RwDirection.forward)) + let a := a.push (← InitEntry.fromExpr rhs (name, RwDirection.backward)) + pure a + | _ => return #[] + catch _e => + throwError "Jhx. Timeout initializing entries" +-- if e.isMaxHeartbeat then +-- else +-- throw e /-- Configuration for `DiscrTree`. -/ def discrTreeConfig : WhnfCoreConfig := {} @@ -69,12 +78,10 @@ def localHypotheses (except : List FVarId := []) : MetaM (Array (Expr × Bool × let (_, _, type) ← forallMetaTelescopeReducing (← inferType h) let type ← whnfR type match type.getAppFnArgs with - | (``Eq, #[_, lhs, rhs]) - | (``Iff, #[lhs, rhs]) => do - let lhsKey : Array DiscrTree.Key ← DiscrTree.mkPath lhs discrTreeConfig - let rhsKey : Array DiscrTree.Key ← DiscrTree.mkPath rhs discrTreeConfig - result := result.push (h, false, forwardWeight * lhsKey.size) - |>.push (h, true, backwardWeight * rhsKey.size) + | (``Eq, #[_, _, _]) + | (``Iff, #[_, _]) => do + result := result.push (h, false, forwardWeight) + |>.push (h, true, backwardWeight) | _ => pure () return result @@ -84,12 +91,12 @@ they match too much. -/ def droppedKeys : List (List LazyDiscrTree.Key) := [[.star], [.const `Eq 3, .star, .star, .star]] -def createModuleTreeRef : MetaM (LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) := +def createModuleTreeRef : MetaM (LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection)) := LazyDiscrTree.createModuleTreeRef addImport droppedKeys -private def ExtState := IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat))) +private def ExtState := IO.Ref (Option (LazyDiscrTree (Name × RwDirection))) -private builtin_initialize ExtState.default : IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat))) ← do +private builtin_initialize ExtState.default : IO.Ref (Option (LazyDiscrTree (Name × RwDirection))) ← do IO.mkRef .none private instance : Inhabited ExtState where @@ -108,11 +115,14 @@ initialization performance. -/ private def constantsPerImportTask : Nat := 6500 -def incPrio : Nat → Name × Bool × Nat → Name × Bool × Nat -| p, (nm, d, prio) => (nm, d, prio * 100 + p) +def incPrio : Nat → Name × RwDirection → Name × Bool × Nat +| q, (nm, d) => + match d with + | .forward => (nm, false, 2 * q) + | .backward => (nm, true, q) /-- Create function for finding relevant declarations. -/ -def rwFindDecls (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) : Expr → MetaM (Array (Name × Bool × Nat)) := +def rwFindDecls (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection)) : Expr → MetaM (Array (Name × Bool × Nat)) := LazyDiscrTree.findMatchesExt moduleRef ext addImport (droppedKeys := droppedKeys) (constantsPerTask := constantsPerImportTask) @@ -134,14 +144,12 @@ structure RewriteResult where mctx : MetavarContext rfl? : Bool -/-- Update a `RewriteResult` by filling in the `rfl?` field if it is currently `none`, -to reflect whether the remaining goal can be closed by `with_reducible rfl`. -/ -def computeRfl (mctx : MetavarContext) (res : Meta.RewriteResult) : MetaM Bool := do +/-- Check to see if this expression (which must be a type) can be closed by `with_reducible rfl`. -/ +def dischargableWithRfl? (mctx : MetavarContext) (e : Expr) : MetaM Bool := do try withoutModifyingState <| withMCtx mctx do -- We use `withReducible` here to follow the behaviour of `rw`. - withReducible (← mkFreshExprMVar res.eNew).mvarId!.applyRfl - -- We do not need to record the updated `MetavarContext` here. + withReducible (← mkFreshExprMVar e).mvarId!.applyRfl pure true catch _e => pure false @@ -168,7 +176,7 @@ def solveByElim (goals : List MVarId) (depth : Nat := 6) : MetaM PUnit := do let [] ← SolveByElim.solveByElim cfg lemmas ctx goals | failure -def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideConditions := .solveByElim) +def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideConditions := .solveByElim) (lem : Expr ⊕ Name) (symm : Bool) (weight : Nat) : MetaM (Option RewriteResult) := withMCtx ctx do let some expr ← (match lem with @@ -180,7 +188,7 @@ def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideC | return none if result.mvarIds.isEmpty then let mctx ← getMCtx - let rfl? ← computeRfl mctx result + let rfl? ← dischargableWithRfl? mctx result.eNew return some { expr, symm, weight, result, mctx, rfl? } else -- There are side conditions, which we try to discharge using local hypotheses. @@ -201,7 +209,7 @@ def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideC else (expr, false) let mctx ← getMCtx - let rfl? ← computeRfl mctx result + let rfl? ← dischargableWithRfl? mctx result.eNew return some { expr, symm, weight, result, mctx, rfl? } /-- @@ -237,7 +245,7 @@ Find lemmas which can rewrite the goal. See also `rewrites` for a more convenient interface. -/ def rewriteCandidates (hyps : Array (Expr × Bool × Nat)) - (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) + (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection)) (target : Expr) (forbidden : NameSet := ∅) : MetaM (Array ((Expr ⊕ Name) × Bool × Nat)) := do @@ -306,7 +314,7 @@ def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : let s ← withoutModifyingState <| withMCtx r.mctx r.ppResult if seen.contains s then continue - let rfl? ← computeRfl r.mctx r.result + let rfl? ← dischargableWithRfl? r.mctx r.result.eNew if cfg.stopAtRfl then if rfl? then return #[r] @@ -320,7 +328,7 @@ def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : /-- Find lemmas which can rewrite the goal. -/ def findRewrites (hyps : Array (Expr × Bool × Nat)) - (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) + (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × RwDirection)) (goal : MVarId) (target : Expr) (forbidden : NameSet := ∅) (side : SideConditions := .solveByElim) (stopAtRfl : Bool) (max : Nat := 20) diff --git a/tests/lean/librarySearch.lean b/tests/lean/librarySearch.lean index 02a7cd5f26a7..ede49e4cbdce 100644 --- a/tests/lean/librarySearch.lean +++ b/tests/lean/librarySearch.lean @@ -21,7 +21,7 @@ noncomputable section #guard_msgs in example (x : Nat) : x ≠ x.succ := Nat.ne_of_lt (by apply?) -/-- info: Try this: exact Nat.zero_lt_succ 1 -/ +/-- info: Try this: exact Nat.lt_of_sub_eq_succ rfl -/ #guard_msgs in example : 0 ≠ 1 + 1 := Nat.ne_of_lt (by apply?) @@ -83,11 +83,11 @@ example (n m k : Nat) : n * m - n * k = n * (m - k) := by #guard_msgs in example {α : Type} (x y : α) : x = y ↔ y = x := by apply? -/-- info: Try this: exact Nat.add_pos_left ha b -/ +/-- info: Try this: exact Nat.lt_add_right b ha -/ #guard_msgs in example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply? -/-- info: Try this: exact Nat.add_pos_left ha b -/ +/-- info: Try this: exact Nat.lt_add_right b ha -/ #guard_msgs in -- Verify that if maxHeartbeats is 0 we don't stop immediately. set_option maxHeartbeats 0 in @@ -95,7 +95,7 @@ example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply? section synonym -/-- info: Try this: exact Nat.add_pos_left ha b -/ +/-- info: Try this: exact Nat.lt_add_right b ha -/ #guard_msgs in example (a b : Nat) (ha : a > 0) (_hb : 0 < b) : 0 < a + b := by apply?