Skip to content

Commit

Permalink
fixes for fun_prop for leanprover/lean4#6503
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-em committed Nov 18, 2024
1 parent 10c8546 commit 7faa75a
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 51 deletions.
1 change: 0 additions & 1 deletion Mathlib/Tactic/FunProp/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ initialize funPropAttr : Unit ←
add := fun declName _stx attrKind =>
discard <| MetaM.run do
let info ← getConstInfo declName

forallTelescope info.type fun _ b => do
if b.isProp then
addFunPropDecl declName
Expand Down
12 changes: 8 additions & 4 deletions Mathlib/Tactic/FunProp/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def applyMorRules (funPropDecl : FunPropDecl) (e : Expr) (fData : FunctionData)
| .exact =>

let ext := morTheoremsExt.getState (← getEnv)
let candidates ← ext.theorems.getMatchWithScore e false { iota := false, zeta := false }
let candidates ← withConfig (fun cfg => { cfg with iota := false, zeta := false }) <|
ext.theorems.getMatchWithScore e false
let candidates := candidates.map (·.1) |>.flatten

trace[Meta.Tactic.fun_prop]
Expand All @@ -344,7 +345,8 @@ def applyTransitionRules (e : Expr) (funProp : Expr → FunPropM (Option Result)
withIncreasedTransitionDepth do

let ext := transitionTheoremsExt.getState (← getEnv)
let candidates ← ext.theorems.getMatchWithScore e false { iota := false, zeta := false }
let candidates ← withConfig (fun cfg => { cfg with iota := false, zeta := false }) <|
ext.theorems.getMatchWithScore e false
let candidates := candidates.map (·.1) |>.flatten

trace[Meta.Tactic.fun_prop]
Expand Down Expand Up @@ -433,7 +435,8 @@ def getLocalTheorems (funPropDecl : FunPropDecl) (funOrigin : Origin)
let .some (decl,f) ← getFunProp? b | return none
unless decl.funPropName = funPropDecl.funPropName do return none

let .data fData ← getFunctionData? f (← unfoldNamePred) {zeta := false, zetaDelta := false}
let .data fData ← withConfig (fun cfg => { cfg with zeta := false, zetaDelta := false }) <|
getFunctionData? f (← unfoldNamePred)
| return none
unless (fData.getFnOrigin == funOrigin) do return none

Expand Down Expand Up @@ -654,7 +657,8 @@ mutual
let e' := e.setArg funPropDecl.funArgId b
funProp (← mkLambdaFVars xs e')

match ← getFunctionData? f (← unfoldNamePred) {zeta := false, zetaDelta := false} with
match ← withConfig (fun cfg => { cfg with zeta := false, zetaDelta := false }) <|
getFunctionData? f (← unfoldNamePred) with
| .letE f =>
trace[Debug.Meta.Tactic.fun_prop] "let case on {← ppExpr f}"
let e := e.setArg funPropDecl.funArgId f -- update e with reduced f
Expand Down
2 changes: 1 addition & 1 deletion Mathlib/Tactic/FunProp/Decl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ the function it talks about. -/
def getFunProp? (e : Expr) : MetaM (Option (FunPropDecl × Expr)) := do
let ext := funPropDeclsExt.getState (← getEnv)

let decls ← ext.decls.getMatch e {}
let decls ← ext.decls.getMatch e (← read)

if decls.size = 0 then
return none
Expand Down
4 changes: 2 additions & 2 deletions Mathlib/Tactic/FunProp/FunctionData.lean
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def MaybeFunctionData.get (fData : MaybeFunctionData) : MetaM Expr :=

/-- Get `FunctionData` for `f`. -/
def getFunctionData? (f : Expr)
(unfoldPred : Name → Bool := fun _ => false) (cfg : WhnfCoreConfig := {}) :
(unfoldPred : Name → Bool := fun _ => false) :
MetaM MaybeFunctionData := do

let unfold := fun e : Expr => do
Expand All @@ -130,7 +130,7 @@ def getFunctionData? (f : Expr)
| throwError m!"fun_prop bug: function expected, got `{f} : {← inferType f}, \
type ctor {(← inferType f).ctorName}"
withLocalDeclD xName xType fun x => do
let fx' := (← Mor.whnfPred (f.beta #[x]).eta unfold cfg) |> headBetaThroughLet
let fx' := (← Mor.whnfPred (f.beta #[x]).eta unfold) |> headBetaThroughLet
let f' ← mkLambdaFVars #[x] fx'
match fx' with
| .letE .. => return .letE f'
Expand Down
14 changes: 7 additions & 7 deletions Mathlib/Tactic/FunProp/Mor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ can specify which when to unfold definitions.
For example calling this on `coe (f a) b` will put `f` in weak normal head form instead of `coe`.
-/
partial def whnfPred (e : Expr) (pred : Expr → MetaM Bool) (cfg : WhnfCoreConfig := {}) :
partial def whnfPred (e : Expr) (pred : Expr → MetaM Bool) :
MetaM Expr := do
whnfEasyCases e fun e => do
let e ← whnfCore e cfg
let e ← whnfCore e

if let .some ⟨coe,f,x⟩ ← isMorApp? e then
let f ← whnfPred f pred cfg
if cfg.zeta then
let f ← whnfPred f pred
if (← getConfig).zeta then
return (coe.app f).app x
else
return ← letTelescope f fun xs f' =>
mkLambdaFVars xs ((coe.app f').app x)

if (← pred e) then
match (← unfoldDefinition? e) with
| some e => whnfPred e pred cfg
| some e => whnfPred e pred
| none => return e
else
return e
Expand All @@ -88,8 +88,8 @@ Weak normal head form of an expression involving morphism applications.
For example calling this on `coe (f a) b` will put `f` in weak normal head form instead of `coe`.
-/
def whnf (e : Expr) (cfg : WhnfCoreConfig := {}) : MetaM Expr :=
whnfPred e (fun _ => return false) cfg
def whnf (e : Expr) : MetaM Expr :=
whnfPred e (fun _ => return false)


/-- Argument of morphism application that stores corresponding coercion if necessary -/
Expand Down
60 changes: 29 additions & 31 deletions Mathlib/Tactic/FunProp/RefinedDiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -438,23 +438,23 @@ where
| _ => failure

/-- Reduction procedure for the `RefinedDiscrTree` indexing. -/
partial def reduce (e : Expr) (config : WhnfCoreConfig) : MetaM Expr := do
let e ← whnfCore e config
partial def reduce (e : Expr) : MetaM Expr := do
let e ← whnfCore e
match (← unfoldDefinition? e) with
| some e => reduce e config
| some e => reduce e
| none => match e.etaExpandedStrict? with
| some e => reduce e config
| some e => reduce e
| none => return e

/-- Repeatedly apply reduce while stripping lambda binders and introducing their variables -/
@[specialize]
partial def lambdaTelescopeReduce {m} [Monad m] [MonadLiftT MetaM m] [MonadControlT MetaM m]
[Nonempty α] (e : Expr) (fvars : List FVarId) (config : WhnfCoreConfig)
[Nonempty α] (e : Expr) (fvars : List FVarId)
(k : Expr → List FVarId → m α) : m α := do
match ← reduce e config with
match ← reduce e with
| .lam n d b bi =>
withLocalDecl n bi d fun fvar =>
lambdaTelescopeReduce (b.instantiate1 fvar) (fvar.fvarId! :: fvars) config k
lambdaTelescopeReduce (b.instantiate1 fvar) (fvar.fvarId! :: fvars) k
| e => k e fvars


Expand Down Expand Up @@ -492,7 +492,6 @@ private structure Context where
bvars : List FVarId := []
/-- Variables that come from a lambda that has been removed via η-reduction. -/
forbiddenVars : List FVarId := []
config : WhnfCoreConfig
fvarInContext : FVarId → Bool

/-- Return for each argument whether it should be ignored. -/
Expand Down Expand Up @@ -633,7 +632,7 @@ private def withLams {m} [Monad m] [MonadWithReader Context m]
/-- Return the encoding of `e` as a `DTExpr`.
If `root = false`, then `e` is a strict sub expression of the original expression. -/
partial def mkDTExprAux (e : Expr) (root : Bool) : ReaderT Context MetaM DTExpr := do
lambdaTelescopeReduce e [] (← read).config fun e lambdas =>
lambdaTelescopeReduce e [] fun e lambdas =>
e.withApp fun fn args => do

let argDTExpr (arg : Expr) (ignore : Bool) : ReaderT Context MetaM DTExpr :=
Expand Down Expand Up @@ -755,7 +754,7 @@ def cacheEtaPossibilities (e original : Expr) (lambdas : List FVarId)
/-- Return all encodings of `e` as a `DTExpr`, taking possible η-reductions into account.
If `root = false`, then `e` is a strict sub expression of the original expression. -/
partial def mkDTExprsAux (original : Expr) (root : Bool) : M DTExpr := do
lambdaTelescopeReduce original [] (← read).config fun e lambdas => do
lambdaTelescopeReduce original [] fun e lambdas => do

if !root then
if let .const n _ := e.getAppFn then
Expand Down Expand Up @@ -849,16 +848,16 @@ Warning: to account for potential η-reductions of `e`, use `mkDTExprs` instead.
The argument `fvarInContext` allows you to specify which free variables in `e` will still be
in the context when the `RefinedDiscrTree` is being used for lookup.
It should return true only if the `RefinedDiscrTree` is built and used locally. -/
def mkDTExpr (e : Expr) (config : WhnfCoreConfig)
def mkDTExpr (e : Expr)
(fvarInContext : FVarId → Bool := fun _ => false) : MetaM DTExpr :=
withReducible do (MkDTExpr.mkDTExprAux e true |>.run {config, fvarInContext})
withReducible do (MkDTExpr.mkDTExprAux e true |>.run {fvarInContext})

/-- Similar to `mkDTExpr`.
Return all encodings of `e` as a `DTExpr`, taking potential further η-reductions into account. -/
def mkDTExprs (e : Expr) (config : WhnfCoreConfig) (onlySpecific : Bool)
def mkDTExprs (e : Expr) (onlySpecific : Bool)
(fvarInContext : FVarId → Bool := fun _ => false) : MetaM (List DTExpr) :=
withReducible do
let es ← (MkDTExpr.mkDTExprsAux e true).run' {} |>.run {config, fvarInContext}
let es ← (MkDTExpr.mkDTExprsAux e true).run' {} |>.run {fvarInContext}
return if onlySpecific then es.filter (·.isSpecific) else es


Expand Down Expand Up @@ -932,18 +931,18 @@ It should return true only if the `RefinedDiscrTree` is built and used locally.
if `onlySpecific := true`, then we filter out the patterns `*` and `Eq * * *`. -/
def insert [BEq α] (d : RefinedDiscrTree α) (e : Expr) (v : α)
(onlySpecific : Bool := true) (config : WhnfCoreConfig := {})
(fvarInContext : FVarId → Bool := fun _ => false) : MetaM (RefinedDiscrTree α) := do
let keys ← mkDTExprs e config onlySpecific fvarInContext
(onlySpecific : Bool := true) (fvarInContext : FVarId → Bool := fun _ => false) :
MetaM (RefinedDiscrTree α) := do
let keys ← mkDTExprs e onlySpecific fvarInContext
return keys.foldl (insertDTExpr · · v) d

/-- Insert the value `vLhs` at index `lhs`, and if `rhs` is indexed differently, then also
insert the value `vRhs` at index `rhs`. -/
def insertEqn [BEq α] (d : RefinedDiscrTree α) (lhs rhs : Expr) (vLhs vRhs : α)
(onlySpecific : Bool := true) (config : WhnfCoreConfig := {})
(fvarInContext : FVarId → Bool := fun _ => false) : MetaM (RefinedDiscrTree α) := do
let keysLhs ← mkDTExprs lhs config onlySpecific fvarInContext
let keysRhs ← mkDTExprs rhs config onlySpecific fvarInContext
(onlySpecific : Bool := true) (fvarInContext : FVarId → Bool := fun _ => false) :
MetaM (RefinedDiscrTree α) := do
let keysLhs ← mkDTExprs lhs onlySpecific fvarInContext
let keysRhs ← mkDTExprs rhs onlySpecific fvarInContext
let d := keysLhs.foldl (insertDTExpr · · vLhs) d
if @List.beq _ ⟨DTExpr.eqv⟩ keysLhs keysRhs then
return d
Expand All @@ -967,7 +966,6 @@ def findKey (children : Array (Key × Trie α)) (k : Key) : Option (Trie α) :=

private structure Context where
unify : Bool
config : WhnfCoreConfig

private structure State where
/-- Score representing how good the match is. -/
Expand All @@ -981,9 +979,9 @@ private structure State where
private abbrev M := ReaderT Context <| StateListM State

/-- Return all values from `x` in an array, together with their scores. -/
private def M.run (unify : Bool) (config : WhnfCoreConfig) (x : M (Trie α)) :
private def M.run (unify : Bool) (x : M (Trie α)) :
Array (Array α × Nat) :=
((x.run { unify, config }).run {}).toArray.map (fun (t, s) => (t.values!, s.score))
((x.run { unify }).run {}).toArray.map (fun (t, s) => (t.values!, s.score))

/-- Increment the score by `n`. -/
private def incrementScore (n : Nat) : M Unit :=
Expand Down Expand Up @@ -1076,7 +1074,7 @@ mutual
end

private partial def getMatchWithScoreAux (d : RefinedDiscrTree α) (e : DTExpr) (unify : Bool)
(config : WhnfCoreConfig) (allowRootStar : Bool := false) : Array (Array α × Nat) := (do
(allowRootStar : Bool := false) : Array (Array α × Nat) := (do
if e matches .star _ then
guard allowRootStar
d.root.foldl (init := failure) fun x k c => (do
Expand All @@ -1090,7 +1088,7 @@ private partial def getMatchWithScoreAux (d : RefinedDiscrTree α) (e : DTExpr)
guard allowRootStar
let some c := d.root.find? (.star 0) | failure
return c
).run unify config
).run unify

end GetUnify

Expand All @@ -1106,22 +1104,22 @@ This is for when you don't want to instantiate metavariables in `e`.
If `allowRootStar := false`, then we don't allow `e` or the matched key in `d`
to be a star pattern. -/
def getMatchWithScore (d : RefinedDiscrTree α) (e : Expr) (unify : Bool)
(config : WhnfCoreConfig) (allowRootStar : Bool := false) : MetaM (Array (Array α × Nat)) := do
let e ← mkDTExpr e config
let result := GetUnify.getMatchWithScoreAux d e unify config allowRootStar
(allowRootStar : Bool := false) : MetaM (Array (Array α × Nat)) := do
let e ← mkDTExpr e
let result := GetUnify.getMatchWithScoreAux d e unify allowRootStar
return result.qsort (·.2 > ·.2)

/-- Similar to `getMatchWithScore`, but also returns matches with prefixes of `e`.
We store the score, followed by the number of ignored arguments. -/
partial def getMatchWithScoreWithExtra (d : RefinedDiscrTree α) (e : Expr) (unify : Bool)
(config : WhnfCoreConfig) (allowRootStar : Bool := false) :
(allowRootStar : Bool := false) :
MetaM (Array (Array α × Nat × Nat)) := do
let result ← go e 0
return result.qsort (·.2.1 > ·.2.1)
where
/-- go -/
go (e : Expr) (numIgnored : Nat) : MetaM (Array (Array α × Nat × Nat)) := do
let result ← getMatchWithScore d e unify config allowRootStar
let result ← getMatchWithScore d e unify allowRootStar
let result := result.map fun (a, b) => (a, b, numIgnored)
match e with
| .app e _ => return (← go e (numIgnored + 1)) ++ result
Expand Down
8 changes: 3 additions & 5 deletions Mathlib/Tactic/FunProp/Theorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,11 @@ type of theorem it is. -/
def getTheoremFromConst (declName : Name) (prio : Nat := eval_prio default) : MetaM Theorem := do
let info ← getConstInfo declName
forallTelescope info.type fun xs b => do

let .some (decl,f) ← getFunProp? b
| throwError "unrecognized function property `{← ppExpr b}`"
let funPropName := decl.funPropName

let fData? ← getFunctionData? f defaultUnfoldPred {zeta := false}

let fData? ←
withConfig (fun cfg => { cfg with zeta := false}) <| getFunctionData? f defaultUnfoldPred
if let .some thmArgs ← detectLambdaTheoremArgs (← fData?.get) xs then
return .lam {
funPropName := funPropName
Expand Down Expand Up @@ -338,7 +336,7 @@ def getTheoremFromConst (declName : Name) (prio : Nat := eval_prio default) : Me
}
| .fvar .. =>
let (_,_,b') ← forallMetaTelescope info.type
let keys := ← RefinedDiscrTree.mkDTExprs b' {} false
let keys := ← RefinedDiscrTree.mkDTExprs b' false
let thm : GeneralTheorem := {
funPropName := funPropName
thmName := declName
Expand Down

0 comments on commit 7faa75a

Please sign in to comment.