From 77911546f7faa000f3e79ae5626d244af8f6def1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 12 Nov 2024 17:24:39 -0800 Subject: [PATCH] refactor: mark the `Simp.Context` constructor as private motivation: this is the first step to fix the mismatch between `isDefEq` and the discrimination tree indexing. --- src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean | 4 +- .../Elab/PreDefinition/Structural/Eqns.lean | 4 +- src/Lean/Elab/PreDefinition/WF/Eqns.lean | 4 +- .../Tactic/BVDecide/Frontend/Normalize.lean | 22 ++++------ src/Lean/Elab/Tactic/Conv/Pattern.lean | 11 +++-- src/Lean/Elab/Tactic/NormCast.lean | 20 ++++++--- src/Lean/Elab/Tactic/Simp.lean | 13 +++--- src/Lean/Elab/Tactic/Simpa.lean | 4 +- src/Lean/Meta/Tactic/AC/Main.lean | 20 ++++----- src/Lean/Meta/Tactic/Acyclic.lean | 5 ++- src/Lean/Meta/Tactic/Grind/Preprocessor.lean | 9 ++-- src/Lean/Meta/Tactic/Simp/Attr.lean | 7 +++- src/Lean/Meta/Tactic/Simp/Main.lean | 23 +++------- src/Lean/Meta/Tactic/Simp/Rewrite.lean | 9 ++-- src/Lean/Meta/Tactic/Simp/SimpAll.lean | 6 +-- src/Lean/Meta/Tactic/Simp/Types.lean | 42 +++++++++++++++++++ src/Lean/Meta/Tactic/Split.lean | 11 +++-- src/Lean/Meta/Tactic/SplitIf.lean | 9 ++-- src/Lean/Meta/Tactic/Unfold.lean | 9 ++-- 19 files changed, 136 insertions(+), 96 deletions(-) diff --git a/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean b/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean index f8acce7db9fa..91c5f8f151c3 100644 --- a/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean @@ -50,7 +50,9 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do go mvarId else if let some mvarId ← whnfReducibleLHS? mvarId then go mvarId - else match (← simpTargetStar mvarId { config.dsimp := false } (simprocs := {})).1 with + else + let ctx ← Simp.mkContext (config := { dsimp := false }) + match (← simpTargetStar mvarId ctx (simprocs := {})).1 with | TacticResultCNM.closed => return () | TacticResultCNM.modified mvarId => go mvarId | TacticResultCNM.noChange => diff --git a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean index 12ad3186afd0..b3e039ae9fd6 100644 --- a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean @@ -45,7 +45,9 @@ where go mvarId else if let some mvarId ← simpIf? mvarId then go mvarId - else match (← simpTargetStar mvarId {} (simprocs := {})).1 with + else + let ctx ← Simp.mkContext + match (← simpTargetStar mvarId ctx (simprocs := {})).1 with | TacticResultCNM.closed => return () | TacticResultCNM.modified mvarId => go mvarId | TacticResultCNM.noChange => diff --git a/src/Lean/Elab/PreDefinition/WF/Eqns.lean b/src/Lean/Elab/PreDefinition/WF/Eqns.lean index 36e9e9b380ca..43c4f23b08bd 100644 --- a/src/Lean/Elab/PreDefinition/WF/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/WF/Eqns.lean @@ -57,7 +57,9 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do go mvarId else if let some mvarId ← whnfReducibleLHS? mvarId then go mvarId - else match (← simpTargetStar mvarId { config.dsimp := false } (simprocs := {})).1 with + else + let ctx ← Simp.mkContext (config := { dsimp := false }) + match (← simpTargetStar mvarId ctx (simprocs := {})).1 with | TacticResultCNM.closed => return () | TacticResultCNM.modified mvarId => go mvarId | TacticResultCNM.noChange => diff --git a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean index 8bad5ffb2659..f325c49950fe 100644 --- a/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean +++ b/src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean @@ -162,13 +162,10 @@ def rewriteRulesPass (maxSteps : Nat) : Pass := fun goal => do let bvSimprocs ← bvNormalizeSimprocExt.getSimprocs let sevalThms ← getSEvalTheorems let sevalSimprocs ← Simp.getSEvalSimprocs - - let simpCtx : Simp.Context := { - config := { failIfUnchanged := false, zetaDelta := true, maxSteps } - simpTheorems := #[bvThms, sevalThms] - congrTheorems := (← getSimpCongrTheorems) - } - + let simpCtx ← Simp.mkContext + (config := { failIfUnchanged := false, zetaDelta := true, maxSteps }) + (simpTheorems := #[bvThms, sevalThms]) + (congrTheorems := (← getSimpCongrTheorems)) let hyps ← goal.getNondepPropHyps let ⟨result?, _⟩ ← simpGoal goal (ctx := simpCtx) @@ -193,13 +190,10 @@ def embeddedConstraintPass (maxSteps : Nat) : Pass := fun goal => let proof := localDecl.toExpr acc.addTheorem (.fvar hyp) proof let relevantHyps : SimpTheoremsArray ← hyps.foldlM (init := #[]) relevanceFilter - - let simpCtx : Simp.Context := { - config := { failIfUnchanged := false, maxSteps } - simpTheorems := relevantHyps - congrTheorems := (← getSimpCongrTheorems) - } - + let simpCtx ← Simp.mkContext + (config := { failIfUnchanged := false, maxSteps }) + (simpTheorems := relevantHyps) + (congrTheorems := (← getSimpCongrTheorems)) let ⟨result?, _⟩ ← simpGoal goal (ctx := simpCtx) (fvarIdsToSimp := hyps) let some (_, newGoal) := result? | return none return newGoal diff --git a/src/Lean/Elab/Tactic/Conv/Pattern.lean b/src/Lean/Elab/Tactic/Conv/Pattern.lean index 13f8cfef1ab2..7faba51f1e8b 100644 --- a/src/Lean/Elab/Tactic/Conv/Pattern.lean +++ b/src/Lean/Elab/Tactic/Conv/Pattern.lean @@ -12,11 +12,10 @@ namespace Lean.Elab.Tactic.Conv open Meta private def getContext : MetaM Simp.Context := do - return { - simpTheorems := {} - congrTheorems := (← getSimpCongrTheorems) - config := Simp.neutralConfig - } + Simp.mkContext + (simpTheorems := {}) + (congrTheorems := (← getSimpCongrTheorems)) + (config := Simp.neutralConfig) partial def matchPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Option (Expr × Array Expr)) := withNewMCtxDepth do @@ -126,7 +125,7 @@ private def pre (pattern : AbstractMVarsResult) (state : IO.Ref PatternMatchStat pure (.occs #[] 0 ids.toList) | _ => throwUnsupportedSyntax let state ← IO.mkRef occs - let ctx := { ← getContext with config.memoize := occs matches .all _ } + let ctx := (← getContext).setMemoize (occs matches .all _) let (result, _) ← Simp.main lhs ctx (methods := { pre := pre patternA state }) let subgoals ← match ← state.get with | .all #[] | .occs _ 0 _ => diff --git a/src/Lean/Elab/Tactic/NormCast.lean b/src/Lean/Elab/Tactic/NormCast.lean index e4e13c379c93..f27113cf9355 100644 --- a/src/Lean/Elab/Tactic/NormCast.lean +++ b/src/Lean/Elab/Tactic/NormCast.lean @@ -28,8 +28,10 @@ def proveEqUsing (s : SimpTheorems) (a b : Expr) : MetaM (Option Simp.Result) := unless ← isDefEq a'.expr b'.expr do return none a'.mkEqTrans (← b'.mkEqSymm b) withReducible do - (go (← Simp.mkDefaultMethods).toMethodsRef - { simpTheorems := #[s], congrTheorems := ← Meta.getSimpCongrTheorems }).run' {} + let ctx ← Simp.mkContext + (simpTheorems := #[s]) + (congrTheorems := ← Meta.getSimpCongrTheorems) + (go (← Simp.mkDefaultMethods).toMethodsRef ctx).run' {} /-- Proves `a = b` by simplifying using move and squash lemmas. -/ def proveEqUsingDown (a b : Expr) : MetaM (Option Simp.Result) := do @@ -191,19 +193,25 @@ def derive (e : Expr) : MetaM Simp.Result := do -- step 1: pre-processing of numerals let r ← withTrace "pre-processing numerals" do let post e := return Simp.Step.done (← try numeralToCoe e catch _ => pure {expr := e}) - r.mkEqTrans (← Simp.main r.expr { config, congrTheorems } (methods := { post })).1 + let ctx ← Simp.mkContext (config := config) (congrTheorems := congrTheorems) + r.mkEqTrans (← Simp.main r.expr ctx (methods := { post })).1 -- step 2: casts are moved upwards and eliminated let r ← withTrace "moving upward, splitting and eliminating" do let post := upwardAndElim (← normCastExt.up.getTheorems) - r.mkEqTrans (← Simp.main r.expr { config, congrTheorems } (methods := { post })).1 + let ctx ← Simp.mkContext (config := config) (congrTheorems := congrTheorems) + r.mkEqTrans (← Simp.main r.expr ctx (methods := { post })).1 let simprocs ← ({} : Simp.SimprocsArray).add `reduceCtorEq false -- step 3: casts are squashed let r ← withTrace "squashing" do let simpTheorems := #[← normCastExt.squash.getTheorems] - r.mkEqTrans (← simp r.expr { simpTheorems, config, congrTheorems } simprocs).1 + let ctx ← Simp.mkContext + (config := config) + (simpTheorems := simpTheorems) + (congrTheorems := congrTheorems) + r.mkEqTrans (← simp r.expr ctx simprocs).1 return r @@ -263,7 +271,7 @@ def evalConvNormCast : Tactic := def evalPushCast : Tactic := fun stx => do let { ctx, simprocs, dischargeWrapper } ← withMainContext do mkSimpContext (simpTheorems := pushCastExt.getTheorems) stx (eraseLocal := false) - let ctx := { ctx with config := { ctx.config with failIfUnchanged := false } } + let ctx := ctx.setFailIfUnchanged false dischargeWrapper.with fun discharge? => discard <| simpLocation ctx simprocs discharge? (expandOptLocation stx[5]) diff --git a/src/Lean/Elab/Tactic/Simp.lean b/src/Lean/Elab/Tactic/Simp.lean index 4a8b27d7f7e6..e69a07faf342 100644 --- a/src/Lean/Elab/Tactic/Simp.lean +++ b/src/Lean/Elab/Tactic/Simp.lean @@ -234,7 +234,7 @@ def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsAr logException ex else throw ex - return { ctx := { ctx with simpTheorems := thmsArray.set! 0 thms }, simprocs, starArg } + return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg } -- If recovery is disabled, then we want simp argument elaboration failures to be exceptions. -- This affects `addSimpTheorem`. if (← read).recover then @@ -311,10 +311,11 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp) simpTheorems let simprocs ← if simpOnly then pure {} else Simp.getSimprocs let congrTheorems ← getSimpCongrTheorems - let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) { - config := (← elabSimpConfig stx[1] (kind := kind)) - simpTheorems := #[simpTheorems], congrTheorems - } + let ctx ← Simp.mkContext + (config := (← elabSimpConfig stx[1] (kind := kind))) + (simpTheorems := #[simpTheorems]) + congrTheorems + let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) ctx if !r.starArg || ignoreStarArg then return { r with dischargeWrapper } else @@ -329,7 +330,7 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp) for h in hs do unless simpTheorems.isErased (.fvar h) do simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr - let ctx := { ctx with simpTheorems } + let ctx := ctx.setSimpTheorems simpTheorems return { ctx, simprocs, dischargeWrapper } register_builtin_option tactic.simp.trace : Bool := { diff --git a/src/Lean/Elab/Tactic/Simpa.lean b/src/Lean/Elab/Tactic/Simpa.lean index 40e38289610e..8ea869a52b5d 100644 --- a/src/Lean/Elab/Tactic/Simpa.lean +++ b/src/Lean/Elab/Tactic/Simpa.lean @@ -36,9 +36,9 @@ deriving instance Repr for UseImplicitLambdaResult let stx ← `(tactic| simp $cfg:optConfig $(disch)? $[only%$only]? $[[$args,*]]?) let { ctx, simprocs, dischargeWrapper } ← withMainContext <| mkSimpContext stx (eraseLocal := false) - let ctx := if unfold.isSome then { ctx with config.autoUnfold := true } else ctx + let ctx := if unfold.isSome then ctx.setAutoUnfold else ctx -- TODO: have `simpa` fail if it doesn't use `simp`. - let ctx := { ctx with config := { ctx.config with failIfUnchanged := false } } + let ctx := ctx.setFailIfUnchanged false dischargeWrapper.with fun discharge? => do let (some (_, g), stats) ← simpGoal (← getMainGoal) ctx (simprocs := simprocs) (simplifyTarget := true) (discharge? := discharge?) diff --git a/src/Lean/Meta/Tactic/AC/Main.lean b/src/Lean/Meta/Tactic/AC/Main.lean index 30362404e077..a047a785b228 100644 --- a/src/Lean/Meta/Tactic/AC/Main.lean +++ b/src/Lean/Meta/Tactic/AC/Main.lean @@ -188,12 +188,10 @@ def post (e : Expr) : SimpM Simp.Step := do | e, _ => return Simp.Step.done { expr := e } def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do - let simpCtx := - { - simpTheorems := {} - congrTheorems := (← getSimpCongrTheorems) - config := Simp.neutralConfig - } + let simpCtx ← Simp.mkContext + (simpTheorems := {}) + (congrTheorems := (← getSimpCongrTheorems)) + (config := Simp.neutralConfig) let tgt ← instantiateMVars (← mvarId.getType) let (res, _) ← Simp.main tgt simpCtx (methods := { post }) applySimpResultToTarget mvarId tgt res @@ -207,12 +205,10 @@ def rewriteUnnormalizedRefl (goal : MVarId) : MetaM Unit := do def acNfHypMeta (goal : MVarId) (fvarId : FVarId) : MetaM (Option MVarId) := do goal.withContext do - let simpCtx := - { - simpTheorems := {} - congrTheorems := (← getSimpCongrTheorems) - config := Simp.neutralConfig - } + let simpCtx ← Simp.mkContext + (simpTheorems := {}) + (congrTheorems := (← getSimpCongrTheorems)) + (config := Simp.neutralConfig) let tgt ← instantiateMVars (← fvarId.getType) let (res, _) ← Simp.main tgt simpCtx (methods := { post }) return (← applySimpResultToLocalDecl goal fvarId res false).map (·.snd) diff --git a/src/Lean/Meta/Tactic/Acyclic.lean b/src/Lean/Meta/Tactic/Acyclic.lean index 59ba3539685c..2770acec5fa2 100644 --- a/src/Lean/Meta/Tactic/Acyclic.lean +++ b/src/Lean/Meta/Tactic/Acyclic.lean @@ -38,7 +38,10 @@ where let sizeOfEq ← mkLT sizeOf_lhs sizeOf_rhs let hlt ← mkFreshExprSyntheticOpaqueMVar sizeOfEq -- TODO: we only need the `sizeOf` simp theorems - match (← simpTarget hlt.mvarId! { config.arith := true, simpTheorems := #[ (← getSimpTheorems) ] } {}).1 with + let ctx ← Simp.mkContext + (config := { arith := true }) + (simpTheorems := #[ (← getSimpTheorems) ]) + match (← simpTarget hlt.mvarId! ctx {}).1 with | some _ => return false | none => let heq ← mkCongrArg sizeOf_lhs.appFn! (← mkEqSymm h) diff --git a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean index 0588a146ddd2..dbbe33c5df77 100644 --- a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean +++ b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean @@ -38,11 +38,10 @@ abbrev PreM := ReaderT Context $ StateRefT State GrindM def PreM.run (x : PreM α) : GrindM α := do let thms ← grindNormExt.getTheorems let simprocs := #[(← grindNormSimprocExt.getSimprocs)] - let simp : Simp.Context := { - config := { arith := true } - simpTheorems := #[thms] - congrTheorems := (← getSimpCongrTheorems) - } + let simp ← Simp.mkContext + (config := { arith := true }) + (simpTheorems := #[thms]) + (congrTheorems := (← getSimpCongrTheorems)) x { simp, simprocs } |>.run' {} def simp (_goal : Goal) (e : Expr) : PreM Simp.Result := do diff --git a/src/Lean/Meta/Tactic/Simp/Attr.lean b/src/Lean/Meta/Tactic/Simp/Attr.lean index 71a84b7f5843..f03bd6625e37 100644 --- a/src/Lean/Meta/Tactic/Simp/Attr.lean +++ b/src/Lean/Meta/Tactic/Simp/Attr.lean @@ -73,7 +73,10 @@ def getSimpTheorems : CoreM SimpTheorems := def getSEvalTheorems : CoreM SimpTheorems := sevalSimpExtension.getTheorems -def Simp.Context.mkDefault : MetaM Context := - return { config := {}, simpTheorems := #[(← Meta.getSimpTheorems)], congrTheorems := (← Meta.getSimpCongrTheorems) } +def Simp.Context.mkDefault : MetaM Context := do + mkContext + (config := {}) + (simpTheorems := #[(← Meta.getSimpTheorems)]) + (congrTheorems := (← Meta.getSimpCongrTheorems)) end Lean.Meta diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 1bb9a68026c1..a7928f3239f2 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -20,18 +20,6 @@ builtin_initialize congrHypothesisExceptionId : InternalExceptionId ← def throwCongrHypothesisFailed : MetaM α := throw <| Exception.internal congrHypothesisExceptionId -/-- - Helper method for bootstrapping purposes. It disables `arith` if support theorems have not been defined yet. --/ -def Config.updateArith (c : Config) : CoreM Config := do - if c.arith then - if (← getEnv).contains ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq then - return c - else - return { c with arith := false } - else - return c - /-- Return true if `e` is of the form `ofNat n` where `n` is a kernel Nat literal -/ def isOfNatNatLit (e : Expr) : Bool := e.isAppOf ``OfNat.ofNat && e.getAppNumArgs >= 3 && (e.getArg! 1).isRawNatLit @@ -256,7 +244,7 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do s ← s.addTheorem (.fvar x.fvarId!) x updated := true if updated then - withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f + withSimpTheorems s f else f else if (← getMethods).wellBehavedDischarge then @@ -463,7 +451,7 @@ private partial def dsimpImpl (e : Expr) : SimpM Expr := do let m ← getMethods let pre := m.dpre >> doNotVisitOfNat >> doNotVisitOfScientific >> doNotVisitCharLit let post := m.dpost >> dsimpReduce - withTheReader Simp.Context (fun ctx => { ctx with inDSimp := true }) do + withInDSimp do transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post) def visitFn (e : Expr) : SimpM Result := do @@ -658,11 +646,12 @@ where trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}" simpLoop e +-- TODO: delete @[inline] def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| withReducible x def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do - let ctx := { ctx with config := (← ctx.config.updateArith), lctxInitIndices := (← getLCtx).numIndices } + let ctx ← ctx.setLctxInitIndices withSimpContext ctx do let (r, s) ← go e methods.toMethodsRef ctx |>.run { stats with } trace[Meta.Tactic.simp.numSteps] "{s.numSteps}" @@ -810,7 +799,7 @@ def simpGoal (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := for fvarId in fvarIdsToSimp do let localDecl ← fvarId.getDecl let type ← instantiateMVars localDecl.type - let ctx := { ctx with simpTheorems := ctx.simpTheorems.eraseTheorem (.fvar localDecl.fvarId) } + let ctx := ctx.setSimpTheorems <| ctx.simpTheorems.eraseTheorem (.fvar localDecl.fvarId) let (r, stats') ← simp type ctx simprocs discharge? stats stats := stats' match r.proof? with @@ -844,7 +833,7 @@ def simpTargetStar (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsAr let localDecl ← h.getDecl let proof := localDecl.toExpr let simpTheorems ← ctx.simpTheorems.addTheorem (.fvar h) proof - ctx := { ctx with simpTheorems } + ctx := ctx.setSimpTheorems simpTheorems match (← simpTarget mvarId ctx simprocs discharge? (stats := stats)) with | (none, stats) => return (TacticResultCNM.closed, stats) | (some mvarId', stats') => diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index c948bd2edb88..f1376fc55e44 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -41,7 +41,7 @@ def discharge?' (thmId : Origin) (x : Expr) (type : Expr) : SimpM Bool := do let ctx ← getContext if ctx.dischargeDepth >= ctx.maxDischargeDepth then return .maxDepth - else withTheReader Context (fun ctx => { ctx with dischargeDepth := ctx.dischargeDepth + 1 }) do + else withIncDischargeDepth do -- We save the state, so that `UsedTheorems` does not accumulate -- `simp` lemmas used during unsuccessful discharging. -- We use `withPreservedCache` to ensure the cache is restored after `discharge?` @@ -446,10 +446,13 @@ def mkSEvalMethods : CoreM Methods := do wellBehavedDischarge := true } -def mkSEvalContext : CoreM Context := do +def mkSEvalContext : MetaM Context := do let s ← getSEvalTheorems let c ← Meta.getSimpCongrTheorems - return { simpTheorems := #[s], congrTheorems := c, config := { ground := true } } + mkContext + (simpTheorems := #[s]) + (congrTheorems := c) + (config := { ground := true }) /-- Invoke ground/symbolic evaluator from `simp`. diff --git a/src/Lean/Meta/Tactic/Simp/SimpAll.lean b/src/Lean/Meta/Tactic/Simp/SimpAll.lean index 4a66231b065b..9c904b5eb0fc 100644 --- a/src/Lean/Meta/Tactic/Simp/SimpAll.lean +++ b/src/Lean/Meta/Tactic/Simp/SimpAll.lean @@ -43,7 +43,7 @@ private def initEntries : M Unit := do let localDecl ← h.getDecl let proof := localDecl.toExpr simpThms ← simpThms.addTheorem (.fvar h) proof - modify fun s => { s with ctx.simpTheorems := simpThms } + modify fun s => { s with ctx := s.ctx.setSimpTheorems simpThms } if hsNonDeps.contains h then -- We only simplify nondependent hypotheses let type ← instantiateMVars localDecl.type @@ -62,7 +62,7 @@ private partial def loop : M Bool := do let ctx := (← get).ctx -- We disable the current entry to prevent it to be simplified to `True` let simpThmsWithoutEntry := (← getSimpTheorems).eraseTheorem entry.id - let ctx := { ctx with simpTheorems := simpThmsWithoutEntry } + let ctx := ctx.setSimpTheorems simpThmsWithoutEntry let (r, stats) ← simpStep (← get).mvarId entry.proof entry.type ctx simprocs (stats := { (← get) with }) modify fun s => { s with usedTheorems := stats.usedTheorems, diag := stats.diag } match r with @@ -98,7 +98,7 @@ private partial def loop : M Bool := do simpThmsNew ← simpThmsNew.addTheorem (.other idNew) (← mkExpectedTypeHint proofNew typeNew) modify fun s => { s with modified := true - ctx.simpTheorems := simpThmsNew + ctx := ctx.setSimpTheorems simpThmsNew entries[i] := { entry with type := typeNew, proof := proofNew, id := .other idNew } } -- simplify target diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index a3c1344065b0..5eb059fac687 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -52,6 +52,7 @@ abbrev Cache := SExprMap Result abbrev CongrCache := ExprMap (Option CongrTheorem) structure Context where + private mk :: config : Config := {} /-- `maxDischargeDepth` from `config` as an `UInt32`. -/ maxDischargeDepth : UInt32 := UInt32.ofNatTruncate config.maxDischargeDepth @@ -103,6 +104,38 @@ structure Context where inDSimp : Bool := false deriving Inhabited +/-- +Helper method for bootstrapping purposes. +It disables `arith` if support theorems have not been defined yet. +-/ +private def updateArith (c : Config) : CoreM Config := do + if c.arith then + if (← getEnv).contains ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq then + return c + else + return { c with arith := false } + else + return c + +def mkContext (config : Config := {}) (simpTheorems : SimpTheoremsArray := {}) (congrTheorems : SimpCongrTheorems := {}) : MetaM Context := do + let config ← updateArith config + return { config, simpTheorems, congrTheorems } + +def Context.setSimpTheorems (c : Context) (simpTheorems : SimpTheoremsArray) : Context := + { c with simpTheorems } + +def Context.setLctxInitIndices (c : Context) : MetaM Context := + return { c with lctxInitIndices := (← getLCtx).numIndices } + +def Context.setAutoUnfold (c : Context) : Context := + { c with config.autoUnfold := true } + +def Context.setFailIfUnchanged (c : Context) (flag : Bool) : Context := + { c with config.failIfUnchanged := flag } + +def Context.setMemoize (c : Context) (flag : Bool) : Context := + { c with config.memoize := flag } + def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool := ctx.simpTheorems.isDeclToUnfold declName @@ -158,6 +191,15 @@ instance : Nonempty MethodsRef := MethodsRefPointed.property abbrev SimpM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM +@[inline] def withIncDischargeDepth : SimpM α → SimpM α := + withTheReader Context (fun ctx => { ctx with dischargeDepth := ctx.dischargeDepth + 1 }) + +@[inline] def withSimpTheorems (s : SimpTheoremsArray) : SimpM α → SimpM α := + withTheReader Context (fun ctx => { ctx with simpTheorems := s }) + +@[inline] def withInDSimp : SimpM α → SimpM α := + withTheReader Context (fun ctx => { ctx with inDSimp := true }) + @[extern "lean_simp"] opaque simp (e : Expr) : SimpM Result diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index bd20c779acfd..0ebcda9ccc0f 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -13,12 +13,11 @@ import Lean.Meta.Tactic.Generalize namespace Lean.Meta namespace Split -def getSimpMatchContext : MetaM Simp.Context := - return { - simpTheorems := {} - congrTheorems := (← getSimpCongrTheorems) - config := { Simp.neutralConfig with dsimp := false } - } +def getSimpMatchContext : MetaM Simp.Context := do + Simp.mkContext + (simpTheorems := {}) + (congrTheorems := (← getSimpCongrTheorems)) + (config := { Simp.neutralConfig with dsimp := false }) def simpMatch (e : Expr) : MetaM Simp.Result := do let discharge? ← SplitIf.mkDischarge? diff --git a/src/Lean/Meta/Tactic/SplitIf.lean b/src/Lean/Meta/Tactic/SplitIf.lean index d6a21e31495f..b1f6f7c11d7a 100644 --- a/src/Lean/Meta/Tactic/SplitIf.lean +++ b/src/Lean/Meta/Tactic/SplitIf.lean @@ -19,11 +19,10 @@ def getSimpContext : MetaM Simp.Context := do s ← s.addConst ``if_neg s ← s.addConst ``dif_pos s ← s.addConst ``dif_neg - return { - simpTheorems := #[s] - congrTheorems := (← getSimpCongrTheorems) - config := { Simp.neutralConfig with dsimp := false } - } + Simp.mkContext + (simpTheorems := #[s]) + (congrTheorems := (← getSimpCongrTheorems)) + (config := { Simp.neutralConfig with dsimp := false }) /-- Default `discharge?` function for `simpIf` methods. diff --git a/src/Lean/Meta/Tactic/Unfold.lean b/src/Lean/Meta/Tactic/Unfold.lean index eb95630fbad6..524819f8b9e1 100644 --- a/src/Lean/Meta/Tactic/Unfold.lean +++ b/src/Lean/Meta/Tactic/Unfold.lean @@ -10,11 +10,10 @@ import Lean.Meta.Tactic.Simp.Main namespace Lean.Meta -private def getSimpUnfoldContext : MetaM Simp.Context := - return { - congrTheorems := (← getSimpCongrTheorems) - config := Simp.neutralConfig - } +private def getSimpUnfoldContext : MetaM Simp.Context := do + Simp.mkContext + (congrTheorems := (← getSimpCongrTheorems)) + (config := Simp.neutralConfig) def unfold (e : Expr) (declName : Name) : MetaM Simp.Result := do if let some unfoldThm ← getUnfoldEqnFor? declName then