Skip to content

Commit

Permalink
refactor: store MetaM configuration at Simp.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Nov 13, 2024
1 parent 624dfa6 commit 1a96ac5
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 52 deletions.
29 changes: 15 additions & 14 deletions src/Lean/Elab/Tactic/DiscrTreeKey.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,22 @@ private def mkKey (e : Expr) (simp : Bool) : MetaM (Array Key) := do
let (_, _, type) ← withReducible <| forallMetaTelescopeReducing e
let type ← whnfR type
if simp then
if let some (_, lhs, _) := type.eq? then
mkPath lhs simpDtConfig
else if let some (lhs, _) := type.iff? then
mkPath lhs simpDtConfig
else if let some (_, lhs, _) := type.ne? then
mkPath lhs simpDtConfig
else if let some p := type.not? then
match p.eq? with
| some (_, lhs, _) =>
mkPath lhs simpDtConfig
| _ => mkPath p simpDtConfig
else
mkPath type simpDtConfig
withSimpGlobalConfig do
if let some (_, lhs, _) := type.eq? then
mkPath lhs
else if let some (lhs, _) := type.iff? then
mkPath lhs
else if let some (_, lhs, _) := type.ne? then
mkPath lhs
else if let some p := type.not? then
match p.eq? with
| some (_, lhs, _) =>
mkPath lhs
| _ => mkPath p
else
mkPath type
else
mkPath type {}
mkPath type

private def getType (t : TSyntax `term) : TermElabM Expr := do
if let `($id:ident) := t then
Expand Down
7 changes: 2 additions & 5 deletions src/Lean/Elab/Tactic/Ext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ structure ExtTheorems where
erased : PHashSet Name := {}
deriving Inhabited

/-- Discrimation tree settings for the `ext` extension. -/
def extExt.config : WhnfCoreConfig := {}

/-- The environment extension to track `@[ext]` theorems. -/
builtin_initialize extExtension :
SimpleScopedEnvExtension ExtTheorem ExtTheorems ←
Expand All @@ -211,7 +208,7 @@ builtin_initialize extExtension :
ordered from high priority to low. -/
@[inline] def getExtTheorems (ty : Expr) : MetaM (Array ExtTheorem) := do
let extTheorems := extExtension.getState (← getEnv)
let arr ← extTheorems.tree.getMatch ty extExt.config
let arr ← extTheorems.tree.getMatch ty
let erasedArr := arr.filter fun thm => !extTheorems.erased.contains thm.declName
-- Using insertion sort because it is stable and the list of matches should be mostly sorted.
-- Most ext theorems have default priority.
Expand Down Expand Up @@ -258,7 +255,7 @@ builtin_initialize registerBuiltinAttribute {
but this theorem proves{indentD declTy}"
let some (ty, lhs, rhs) := declTy.eq? | failNotEq
unless lhs.isMVar && rhs.isMVar do failNotEq
let keys ← withReducible <| DiscrTree.mkPath ty extExt.config
let keys ← withReducible <| DiscrTree.mkPath ty
let priority ← liftCommandElabM <| Elab.liftMacroM do evalPrio (prio.getD (← `(prio| default)))
extExtension.add {declName, keys, priority} kind
-- Realize iff theorem
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Tactic/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def elabSimprocPattern (stx : Syntax) : MetaM Expr := do

def elabSimprocKeys (stx : Syntax) : MetaM (Array Meta.SimpTheoremKey) := do
let pattern ← elabSimprocPattern stx
DiscrTree.mkPath pattern simpDtConfig
withSimpGlobalConfig <| DiscrTree.mkPath pattern

def checkSimprocType (declName : Name) : CoreM Bool := do
let decl ← getConstInfo declName
Expand Down
2 changes: 2 additions & 0 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ structure Config where
Zeta-delta reduction: given a local context containing entry `x : t := e`, free variable `x` reduces to `e`.
-/
zetaDelta : Bool := true
deriving Inhabited

/-- Convert `isDefEq` and `WHNF` relevant parts into a key for caching results -/
private def Config.toKey (c : Config) : UInt64 :=
Expand All @@ -213,6 +214,7 @@ structure ConfigWithKey where
private mk ::
config : Config
key : UInt64
deriving Inhabited

def Config.toConfigWithKey (c : Config) : ConfigWithKey :=
{ config := c, key := c.toKey }
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (tag :
where
/-- For `(← getConfig).index := true`, use discrimination tree structure when collecting `simp` theorem candidates. -/
rewriteUsingIndex? : SimpM (Option Result) := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
let candidates ← withSimpConfig <| s.getMatchWithExtra e
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
Expand All @@ -221,7 +221,7 @@ where
Only the root symbol is taken into account. Most of the structure of the discrimination tree is ignored.
-/
rewriteNoIndex? : SimpM (Option Result) := do
let (candidates, numArgs) ← s.getMatchLiberal e (getDtConfig (← getConfig))
let (candidates, numArgs) ← withSimpConfig <| s.getMatchLiberal e
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
Expand All @@ -245,7 +245,7 @@ where

diagnoseWhenNoIndex (thm : SimpTheorem) : SimpM Unit := do
if (← isDiagnosticsEnabled) then
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
let candidates ← withSimpConfig <| s.getMatchWithExtra e
for (candidate, _) in candidates do
if unsafe ptrEq thm candidate then
return ()
Expand Down
49 changes: 34 additions & 15 deletions src/Lean/Meta/Tactic/Simp/SimpTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,18 @@ structure SimpTheorems where
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited

/-- Configuration for the discrimination tree. -/
def simpDtConfig : WhnfCoreConfig := { iota := false, proj := .no, zetaDelta := false }
/--
Configuration for `MetaM` used to process global simp theorems
-/
def simpGlobalConfig : ConfigWithKey :=
{ iota := false
proj := .no
zetaDelta := false
transparency := .reducible
: Config }.toConfigWithKey

@[inline] def withSimpGlobalConfig : MetaM α → MetaM α :=
withConfigWithKey simpGlobalConfig

partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
Expand Down Expand Up @@ -298,7 +308,7 @@ private partial def isPerm : Expr → Expr → MetaM Bool
| s, t => return s == t

private def checkBadRewrite (lhs rhs : Expr) : MetaM Unit := do
let lhs ← DiscrTree.reduceDT lhs (root := true) simpDtConfig
let lhs ← withSimpGlobalConfig <| DiscrTree.reduceDT lhs (root := true)
if lhs == rhs && lhs.isFVar then
throwError "invalid `simp` theorem, equation is equivalent to{indentExpr (← mkEq lhs rhs)}"

Expand Down Expand Up @@ -381,11 +391,11 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
withNewMCtxDepth do
let (_, _, type) ← withReducible <| forallMetaTelescopeReducing type
let (_, _, type) ← forallMetaTelescopeReducing type
let type ← whnfR type
let (keys, perm) ←
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs simpDtConfig noIndexAtArgs, ← isPerm lhs rhs)
| some (_, lhs, rhs) => withSimpGlobalConfig <| pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs)
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }

Expand All @@ -394,7 +404,7 @@ private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool)
let us := cinfo.levelParams.map mkLevelParam
let origin := .decl declName post inv
let val := mkConst declName us
withReducible do
withSimpGlobalConfig do
let type ← inferType val
checkTypeIsProp type
if inv || (← shouldPreprocess type) then
Expand Down Expand Up @@ -464,18 +474,10 @@ private def preprocessProof (val : Expr) (inv : Bool) : MetaM (Array Expr) := do
return ps.toArray.map fun (val, _) => val

/-- Auxiliary method for creating simp theorems from a proof term `val`. -/
def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) :=
private def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) :=
withReducible do
(← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)

/-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/
def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) (post := true) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
if proof.isConst then
s.addConst proof.constName! post inv prio
else
let simpThms ← mkSimpTheorems id levelParams proof post inv prio
return simpThms.foldl addSimpTheoremEntry s

/--
Reducible functions and projection functions should always be put in `toUnfold`, instead
of trying to use equational theorems.
Expand Down Expand Up @@ -533,8 +535,25 @@ def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM Si
else
return d.addDeclToUnfoldCore declName

/-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/
-- TODO: It is used internally
def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) (post := true) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
if proof.isConst then
s.addConst proof.constName! post inv prio
else
let simpThms ← mkSimpTheorems id levelParams proof post inv prio
return simpThms.foldl addSimpTheoremEntry s

abbrev SimpTheoremsArray := Array SimpTheorems

/-
This API is used to
- Initialize bv_decide normalizer
- Initialize `*` at `simp` frontend
- Add contextual theorems
- `simpTargetStar`
- `simpAll`
-/
def SimpTheoremsArray.addTheorem (thmsArray : SimpTheoremsArray) (id : Origin) (h : Expr) : MetaM SimpTheoremsArray :=
if thmsArray.isEmpty then
let thms : SimpTheorems := {}
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Simp/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def SimprocEntry.tryD (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM
| .inr proc => return (← proc e).addExtraArgs extraArgs

def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
let candidates ← withSimpConfig <| s.getMatchWithExtra e
if candidates.isEmpty then
let tag := if post then "post" else "pre"
trace[Debug.Meta.Tactic.simp] "no {tag}-simprocs found for {e}"
Expand Down Expand Up @@ -250,7 +250,7 @@ def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Ex
return .continue

def dsimprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM DStep := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
let candidates ← withSimpConfig <| s.getMatchWithExtra e
if candidates.isEmpty then
let tag := if post then "post" else "pre"
trace[Debug.Meta.Tactic.simp] "no {tag}-simprocs found for {e}"
Expand Down
32 changes: 20 additions & 12 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ abbrev CongrCache := ExprMap (Option CongrTheorem)

structure Context where
private mk ::
config : Config := {}
config : Config := {}
metaConfig : ConfigWithKey := default
/-- `maxDischargeDepth` from `config` as an `UInt32`. -/
maxDischargeDepth : UInt32 := UInt32.ofNatTruncate config.maxDischargeDepth
simpTheorems : SimpTheoremsArray := {}
Expand Down Expand Up @@ -117,9 +118,18 @@ private def updateArith (c : Config) : CoreM Config := do
else
return c

/--
Converts `Simp.Config` into `Meta.ConfigWithKey`.
-/
private def mkMetaConfig (c : Config) : ConfigWithKey :=
{ c with
proj := if c.proj then .yesWithDelta else .no
transparency := .reducible
: Meta.Config }.toConfigWithKey

def mkContext (config : Config := {}) (simpTheorems : SimpTheoremsArray := {}) (congrTheorems : SimpCongrTheorems := {}) : MetaM Context := do
let config ← updateArith config
return { config, simpTheorems, congrTheorems }
return { config, simpTheorems, congrTheorems, metaConfig := mkMetaConfig config }

def Context.setSimpTheorems (c : Context) (simpTheorems : SimpTheoremsArray) : Context :=
{ c with simpTheorems }
Expand Down Expand Up @@ -200,6 +210,14 @@ abbrev SimpM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM
@[inline] def withInDSimp : SimpM α → SimpM α :=
withTheReader Context (fun ctx => { ctx with inDSimp := true })

/--
Executes `x` using a `MetaM` configuration inferred from `Simp.Config`.
For example, if the user has set `simp (config := { zeta := false })`,
`isDefEq` and `whnf` in `MetaM` should not perform `zeta` reduction.
-/
@[inline] def withSimpConfig (x : SimpM α) : SimpM α := do
withConfigWithKey (← readThe Simp.Context).metaConfig x

@[extern "lean_simp"]
opaque simp (e : Expr) : SimpM Result

Expand Down Expand Up @@ -676,16 +694,6 @@ def tryAutoCongrTheorem? (e : Expr) : SimpM (Option Result) := do
/- See comment above. This is reachable if `hasCast == true`. The `rhs` is not structurally equal to `mkAppN f argsNew` -/
return some { expr := rhs }

/--
Return a WHNF configuration for retrieving `[simp]` from the discrimination tree.
If user has disabled `zeta` and/or `beta` reduction in the simplifier, or enabled `zetaDelta`,
we must also disable/enable them when retrieving lemmas from discrimination tree. See issues: #2669 and #2281
-/
def getDtConfig (cfg : Config) : WhnfCoreConfig :=
match cfg.beta, cfg.zeta, cfg.zetaDelta with
| true, true, false => simpDtConfig
| _, _, _ => { simpDtConfig with zeta := cfg.zeta, beta := cfg.beta, zetaDelta := cfg.zetaDelta }

def Result.addExtraArgs (r : Result) (extraArgs : Array Expr) : MetaM Result := do
match r.proof? with
| none => return { expr := mkAppN r.expr extraArgs }
Expand Down

0 comments on commit 1a96ac5

Please sign in to comment.