Skip to content

Commit

Permalink
refactor: mark the Simp.Context constructor as private
Browse files Browse the repository at this point in the history
motivation: this is the first step to fix the mismatch
between `isDefEq` and the discrimination tree indexing.
  • Loading branch information
leodemoura committed Nov 13, 2024
1 parent 456e6d2 commit 7791154
Show file tree
Hide file tree
Showing 19 changed files with 136 additions and 96 deletions.
4 changes: 3 additions & 1 deletion src/Lean/Elab/PreDefinition/Nonrec/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else match (← simpTargetStar mvarId { config.dsimp := false } (simprocs := {})).1 with
else
let ctx ← Simp.mkContext (config := { dsimp := false })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Elab/PreDefinition/Structural/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ where
go mvarId
else if let some mvarId ← simpIf? mvarId then
go mvarId
else match (← simpTargetStar mvarId {} (simprocs := {})).1 with
else
let ctx ← Simp.mkContext
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Elab/PreDefinition/WF/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
go mvarId
else if let some mvarId ← whnfReducibleLHS? mvarId then
go mvarId
else match (← simpTargetStar mvarId { config.dsimp := false } (simprocs := {})).1 with
else
let ctx ← Simp.mkContext (config := { dsimp := false })
match (← simpTargetStar mvarId ctx (simprocs := {})).1 with
| TacticResultCNM.closed => return ()
| TacticResultCNM.modified mvarId => go mvarId
| TacticResultCNM.noChange =>
Expand Down
22 changes: 8 additions & 14 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,10 @@ def rewriteRulesPass (maxSteps : Nat) : Pass := fun goal => do
let bvSimprocs ← bvNormalizeSimprocExt.getSimprocs
let sevalThms ← getSEvalTheorems
let sevalSimprocs ← Simp.getSEvalSimprocs

let simpCtx : Simp.Context := {
config := { failIfUnchanged := false, zetaDelta := true, maxSteps }
simpTheorems := #[bvThms, sevalThms]
congrTheorems := (← getSimpCongrTheorems)
}

let simpCtx ← Simp.mkContext
(config := { failIfUnchanged := false, zetaDelta := true, maxSteps })
(simpTheorems := #[bvThms, sevalThms])
(congrTheorems := (← getSimpCongrTheorems))
let hyps ← goal.getNondepPropHyps
let ⟨result?, _⟩ ← simpGoal goal
(ctx := simpCtx)
Expand All @@ -193,13 +190,10 @@ def embeddedConstraintPass (maxSteps : Nat) : Pass := fun goal =>
let proof := localDecl.toExpr
acc.addTheorem (.fvar hyp) proof
let relevantHyps : SimpTheoremsArray ← hyps.foldlM (init := #[]) relevanceFilter

let simpCtx : Simp.Context := {
config := { failIfUnchanged := false, maxSteps }
simpTheorems := relevantHyps
congrTheorems := (← getSimpCongrTheorems)
}

let simpCtx ← Simp.mkContext
(config := { failIfUnchanged := false, maxSteps })
(simpTheorems := relevantHyps)
(congrTheorems := (← getSimpCongrTheorems))
let ⟨result?, _⟩ ← simpGoal goal (ctx := simpCtx) (fvarIdsToSimp := hyps)
let some (_, newGoal) := result? | return none
return newGoal
Expand Down
11 changes: 5 additions & 6 deletions src/Lean/Elab/Tactic/Conv/Pattern.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ namespace Lean.Elab.Tactic.Conv
open Meta

private def getContext : MetaM Simp.Context := do
return {
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
Simp.mkContext
(simpTheorems := {})
(congrTheorems := (← getSimpCongrTheorems))
(config := Simp.neutralConfig)

partial def matchPattern? (pattern : AbstractMVarsResult) (e : Expr) : MetaM (Option (Expr × Array Expr)) :=
withNewMCtxDepth do
Expand Down Expand Up @@ -126,7 +125,7 @@ private def pre (pattern : AbstractMVarsResult) (state : IO.Ref PatternMatchStat
pure (.occs #[] 0 ids.toList)
| _ => throwUnsupportedSyntax
let state ← IO.mkRef occs
let ctx := { ← getContext with config.memoize := occs matches .all _ }
let ctx := (← getContext).setMemoize (occs matches .all _)
let (result, _) ← Simp.main lhs ctx (methods := { pre := pre patternA state })
let subgoals ← match ← state.get with
| .all #[] | .occs _ 0 _ =>
Expand Down
20 changes: 14 additions & 6 deletions src/Lean/Elab/Tactic/NormCast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def proveEqUsing (s : SimpTheorems) (a b : Expr) : MetaM (Option Simp.Result) :=
unless ← isDefEq a'.expr b'.expr do return none
a'.mkEqTrans (← b'.mkEqSymm b)
withReducible do
(go (← Simp.mkDefaultMethods).toMethodsRef
{ simpTheorems := #[s], congrTheorems := ← Meta.getSimpCongrTheorems }).run' {}
let ctx ← Simp.mkContext
(simpTheorems := #[s])
(congrTheorems := ← Meta.getSimpCongrTheorems)
(go (← Simp.mkDefaultMethods).toMethodsRef ctx).run' {}

/-- Proves `a = b` by simplifying using move and squash lemmas. -/
def proveEqUsingDown (a b : Expr) : MetaM (Option Simp.Result) := do
Expand Down Expand Up @@ -191,19 +193,25 @@ def derive (e : Expr) : MetaM Simp.Result := do
-- step 1: pre-processing of numerals
let r ← withTrace "pre-processing numerals" do
let post e := return Simp.Step.done (← try numeralToCoe e catch _ => pure {expr := e})
r.mkEqTrans (← Simp.main r.expr { config, congrTheorems } (methods := { post })).1
let ctx ← Simp.mkContext (config := config) (congrTheorems := congrTheorems)
r.mkEqTrans (← Simp.main r.expr ctx (methods := { post })).1

-- step 2: casts are moved upwards and eliminated
let r ← withTrace "moving upward, splitting and eliminating" do
let post := upwardAndElim (← normCastExt.up.getTheorems)
r.mkEqTrans (← Simp.main r.expr { config, congrTheorems } (methods := { post })).1
let ctx ← Simp.mkContext (config := config) (congrTheorems := congrTheorems)
r.mkEqTrans (← Simp.main r.expr ctx (methods := { post })).1

let simprocs ← ({} : Simp.SimprocsArray).add `reduceCtorEq false

-- step 3: casts are squashed
let r ← withTrace "squashing" do
let simpTheorems := #[← normCastExt.squash.getTheorems]
r.mkEqTrans (← simp r.expr { simpTheorems, config, congrTheorems } simprocs).1
let ctx ← Simp.mkContext
(config := config)
(simpTheorems := simpTheorems)
(congrTheorems := congrTheorems)
r.mkEqTrans (← simp r.expr ctx simprocs).1

return r

Expand Down Expand Up @@ -263,7 +271,7 @@ def evalConvNormCast : Tactic :=
def evalPushCast : Tactic := fun stx => do
let { ctx, simprocs, dischargeWrapper } ← withMainContext do
mkSimpContext (simpTheorems := pushCastExt.getTheorems) stx (eraseLocal := false)
let ctx := { ctx with config := { ctx.config with failIfUnchanged := false } }
let ctx := ctx.setFailIfUnchanged false
dischargeWrapper.with fun discharge? =>
discard <| simpLocation ctx simprocs discharge? (expandOptLocation stx[5])

Expand Down
13 changes: 7 additions & 6 deletions src/Lean/Elab/Tactic/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsAr
logException ex
else
throw ex
return { ctx := { ctx with simpTheorems := thmsArray.set! 0 thms }, simprocs, starArg }
return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg }
-- If recovery is disabled, then we want simp argument elaboration failures to be exceptions.
-- This affects `addSimpTheorem`.
if (← read).recover then
Expand Down Expand Up @@ -311,10 +311,11 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp)
simpTheorems
let simprocs ← if simpOnly then pure {} else Simp.getSimprocs
let congrTheorems ← getSimpCongrTheorems
let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) {
config := (← elabSimpConfig stx[1] (kind := kind))
simpTheorems := #[simpTheorems], congrTheorems
}
let ctx ← Simp.mkContext
(config := (← elabSimpConfig stx[1] (kind := kind)))
(simpTheorems := #[simpTheorems])
congrTheorems
let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) ctx
if !r.starArg || ignoreStarArg then
return { r with dischargeWrapper }
else
Expand All @@ -329,7 +330,7 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp)
for h in hs do
unless simpTheorems.isErased (.fvar h) do
simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr
let ctx := { ctx with simpTheorems }
let ctx := ctx.setSimpTheorems simpTheorems
return { ctx, simprocs, dischargeWrapper }

register_builtin_option tactic.simp.trace : Bool := {
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/Tactic/Simpa.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ deriving instance Repr for UseImplicitLambdaResult
let stx ← `(tactic| simp $cfg:optConfig $(disch)? $[only%$only]? $[[$args,*]]?)
let { ctx, simprocs, dischargeWrapper } ←
withMainContext <| mkSimpContext stx (eraseLocal := false)
let ctx := if unfold.isSome then { ctx with config.autoUnfold := true } else ctx
let ctx := if unfold.isSome then ctx.setAutoUnfold else ctx
-- TODO: have `simpa` fail if it doesn't use `simp`.
let ctx := { ctx with config := { ctx.config with failIfUnchanged := false } }
let ctx := ctx.setFailIfUnchanged false
dischargeWrapper.with fun discharge? => do
let (some (_, g), stats) ← simpGoal (← getMainGoal) ctx (simprocs := simprocs)
(simplifyTarget := true) (discharge? := discharge?)
Expand Down
20 changes: 8 additions & 12 deletions src/Lean/Meta/Tactic/AC/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,10 @@ def post (e : Expr) : SimpM Simp.Step := do
| e, _ => return Simp.Step.done { expr := e }

def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let simpCtx ← Simp.mkContext
(simpTheorems := {})
(congrTheorems := (← getSimpCongrTheorems))
(config := Simp.neutralConfig)
let tgt ← instantiateMVars (← mvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
applySimpResultToTarget mvarId tgt res
Expand All @@ -207,12 +205,10 @@ def rewriteUnnormalizedRefl (goal : MVarId) : MetaM Unit := do

def acNfHypMeta (goal : MVarId) (fvarId : FVarId) : MetaM (Option MVarId) := do
goal.withContext do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let simpCtx ← Simp.mkContext
(simpTheorems := {})
(congrTheorems := (← getSimpCongrTheorems))
(config := Simp.neutralConfig)
let tgt ← instantiateMVars (← fvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
return (← applySimpResultToLocalDecl goal fvarId res false).map (·.snd)
Expand Down
5 changes: 4 additions & 1 deletion src/Lean/Meta/Tactic/Acyclic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ where
let sizeOfEq ← mkLT sizeOf_lhs sizeOf_rhs
let hlt ← mkFreshExprSyntheticOpaqueMVar sizeOfEq
-- TODO: we only need the `sizeOf` simp theorems
match (← simpTarget hlt.mvarId! { config.arith := true, simpTheorems := #[ (← getSimpTheorems) ] } {}).1 with
let ctx ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[ (← getSimpTheorems) ])
match (← simpTarget hlt.mvarId! ctx {}).1 with
| some _ => return false
| none =>
let heq ← mkCongrArg sizeOf_lhs.appFn! (← mkEqSymm h)
Expand Down
9 changes: 4 additions & 5 deletions src/Lean/Meta/Tactic/Grind/Preprocessor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ abbrev PreM := ReaderT Context $ StateRefT State GrindM
def PreM.run (x : PreM α) : GrindM α := do
let thms ← grindNormExt.getTheorems
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
let simp : Simp.Context := {
config := { arith := true }
simpTheorems := #[thms]
congrTheorems := (← getSimpCongrTheorems)
}
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
(congrTheorems := (← getSimpCongrTheorems))
x { simp, simprocs } |>.run' {}

def simp (_goal : Goal) (e : Expr) : PreM Simp.Result := do
Expand Down
7 changes: 5 additions & 2 deletions src/Lean/Meta/Tactic/Simp/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def getSimpTheorems : CoreM SimpTheorems :=
def getSEvalTheorems : CoreM SimpTheorems :=
sevalSimpExtension.getTheorems

def Simp.Context.mkDefault : MetaM Context :=
return { config := {}, simpTheorems := #[(← Meta.getSimpTheorems)], congrTheorems := (← Meta.getSimpCongrTheorems) }
def Simp.Context.mkDefault : MetaM Context := do
mkContext
(config := {})
(simpTheorems := #[(← Meta.getSimpTheorems)])
(congrTheorems := (← Meta.getSimpCongrTheorems))

end Lean.Meta
23 changes: 6 additions & 17 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ builtin_initialize congrHypothesisExceptionId : InternalExceptionId ←
def throwCongrHypothesisFailed : MetaM α :=
throw <| Exception.internal congrHypothesisExceptionId

/--
Helper method for bootstrapping purposes. It disables `arith` if support theorems have not been defined yet.
-/
def Config.updateArith (c : Config) : CoreM Config := do
if c.arith then
if (← getEnv).contains ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq then
return c
else
return { c with arith := false }
else
return c

/-- Return true if `e` is of the form `ofNat n` where `n` is a kernel Nat literal -/
def isOfNatNatLit (e : Expr) : Bool :=
e.isAppOf ``OfNat.ofNat && e.getAppNumArgs >= 3 && (e.getArg! 1).isRawNatLit
Expand Down Expand Up @@ -256,7 +244,7 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
s ← s.addTheorem (.fvar x.fvarId!) x
updated := true
if updated then
withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f
withSimpTheorems s f
else
f
else if (← getMethods).wellBehavedDischarge then
Expand Down Expand Up @@ -463,7 +451,7 @@ private partial def dsimpImpl (e : Expr) : SimpM Expr := do
let m ← getMethods
let pre := m.dpre >> doNotVisitOfNat >> doNotVisitOfScientific >> doNotVisitCharLit
let post := m.dpost >> dsimpReduce
withTheReader Simp.Context (fun ctx => { ctx with inDSimp := true }) do
withInDSimp do
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)

def visitFn (e : Expr) : SimpM Result := do
Expand Down Expand Up @@ -658,11 +646,12 @@ where
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e

-- TODO: delete
@[inline] def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α :=
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| withReducible x

def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do
let ctx := { ctx with config := (← ctx.config.updateArith), lctxInitIndices := (← getLCtx).numIndices }
let ctx ← ctx.setLctxInitIndices
withSimpContext ctx do
let (r, s) ← go e methods.toMethodsRef ctx |>.run { stats with }
trace[Meta.Tactic.simp.numSteps] "{s.numSteps}"
Expand Down Expand Up @@ -810,7 +799,7 @@ def simpGoal (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray :=
for fvarId in fvarIdsToSimp do
let localDecl ← fvarId.getDecl
let type ← instantiateMVars localDecl.type
let ctx := { ctx with simpTheorems := ctx.simpTheorems.eraseTheorem (.fvar localDecl.fvarId) }
let ctx := ctx.setSimpTheorems <| ctx.simpTheorems.eraseTheorem (.fvar localDecl.fvarId)
let (r, stats') ← simp type ctx simprocs discharge? stats
stats := stats'
match r.proof? with
Expand Down Expand Up @@ -844,7 +833,7 @@ def simpTargetStar (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsAr
let localDecl ← h.getDecl
let proof := localDecl.toExpr
let simpTheorems ← ctx.simpTheorems.addTheorem (.fvar h) proof
ctx := { ctx with simpTheorems }
ctx := ctx.setSimpTheorems simpTheorems
match (← simpTarget mvarId ctx simprocs discharge? (stats := stats)) with
| (none, stats) => return (TacticResultCNM.closed, stats)
| (some mvarId', stats') =>
Expand Down
9 changes: 6 additions & 3 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def discharge?' (thmId : Origin) (x : Expr) (type : Expr) : SimpM Bool := do
let ctx ← getContext
if ctx.dischargeDepth >= ctx.maxDischargeDepth then
return .maxDepth
else withTheReader Context (fun ctx => { ctx with dischargeDepth := ctx.dischargeDepth + 1 }) do
else withIncDischargeDepth do
-- We save the state, so that `UsedTheorems` does not accumulate
-- `simp` lemmas used during unsuccessful discharging.
-- We use `withPreservedCache` to ensure the cache is restored after `discharge?`
Expand Down Expand Up @@ -446,10 +446,13 @@ def mkSEvalMethods : CoreM Methods := do
wellBehavedDischarge := true
}

def mkSEvalContext : CoreM Context := do
def mkSEvalContext : MetaM Context := do
let s ← getSEvalTheorems
let c ← Meta.getSimpCongrTheorems
return { simpTheorems := #[s], congrTheorems := c, config := { ground := true } }
mkContext
(simpTheorems := #[s])
(congrTheorems := c)
(config := { ground := true })

/--
Invoke ground/symbolic evaluator from `simp`.
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Meta/Tactic/Simp/SimpAll.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private def initEntries : M Unit := do
let localDecl ← h.getDecl
let proof := localDecl.toExpr
simpThms ← simpThms.addTheorem (.fvar h) proof
modify fun s => { s with ctx.simpTheorems := simpThms }
modify fun s => { s with ctx := s.ctx.setSimpTheorems simpThms }
if hsNonDeps.contains h then
-- We only simplify nondependent hypotheses
let type ← instantiateMVars localDecl.type
Expand All @@ -62,7 +62,7 @@ private partial def loop : M Bool := do
let ctx := (← get).ctx
-- We disable the current entry to prevent it to be simplified to `True`
let simpThmsWithoutEntry := (← getSimpTheorems).eraseTheorem entry.id
let ctx := { ctx with simpTheorems := simpThmsWithoutEntry }
let ctx := ctx.setSimpTheorems simpThmsWithoutEntry
let (r, stats) ← simpStep (← get).mvarId entry.proof entry.type ctx simprocs (stats := { (← get) with })
modify fun s => { s with usedTheorems := stats.usedTheorems, diag := stats.diag }
match r with
Expand Down Expand Up @@ -98,7 +98,7 @@ private partial def loop : M Bool := do
simpThmsNew ← simpThmsNew.addTheorem (.other idNew) (← mkExpectedTypeHint proofNew typeNew)
modify fun s => { s with
modified := true
ctx.simpTheorems := simpThmsNew
ctx := ctx.setSimpTheorems simpThmsNew
entries[i] := { entry with type := typeNew, proof := proofNew, id := .other idNew }
}
-- simplify target
Expand Down
Loading

0 comments on commit 7791154

Please sign in to comment.