Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: ArgsPacker #3621

Merged
merged 6 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/Lean/Elab/PreDefinition/WF/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Lean.Meta.Tactic.Rewrite
import Lean.Meta.Tactic.Split
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Eqns
import Lean.Meta.ArgsPacker.Basic

namespace Lean.Elab.WF
open Meta
Expand All @@ -17,6 +18,7 @@ structure EqnInfo extends EqnInfoCore where
declNames : Array Name
declNameNonRec : Name
fixedPrefixSize : Nat
argsPacker : ArgsPacker
deriving Inhabited

private partial def deltaLHSUntilFix (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
Expand Down Expand Up @@ -129,7 +131,8 @@ def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=

builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension

def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat) : MetaM Unit := do
def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fixedPrefixSize : Nat)
(argsPacker : ArgsPacker) : MetaM Unit := do
/-
See issue #2327.
Remark: we could do better for mutual declarations that mix theorems and definitions. However, this is a rare
Expand All @@ -140,7 +143,8 @@ def registerEqnsInfo (preDefs : Array PreDefinition) (declNameNonRec : Name) (fi
let declNames := preDefs.map (·.declName)
modifyEnv fun env =>
preDefs.foldl (init := env) fun env preDef =>
eqnInfoExt.insert env preDef.declName { preDef with declNames, declNameNonRec, fixedPrefixSize }
eqnInfoExt.insert env preDef.declName { preDef with
declNames, declNameNonRec, fixedPrefixSize, argsPacker }

def getEqnsFor? (declName : Name) : MetaM (Option (Array Name)) := do
if let some info := eqnInfoExt.find? (← getEnv) declName then
Expand Down
16 changes: 8 additions & 8 deletions src/Lean/Elab/PreDefinition/WF/Fix.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import Lean.Util.HasConstCache
import Lean.Meta.Match.Match
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Cleanup
import Lean.Meta.ArgsPacker
import Lean.Elab.Tactic.Basic
import Lean.Elab.RecAppSyntax
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Structural.Basic
import Lean.Elab.PreDefinition.Structural.BRecOn
import Lean.Elab.PreDefinition.WF.PackMutual
import Lean.Data.Array

namespace Lean.Elab.WF
Expand Down Expand Up @@ -172,19 +172,19 @@ know which function is making the call.
The close coupling with how arguments are packed and termination goals look like is not great,
but it works for now.
-/
def groupGoalsByFunction (numFuncs : Nat) (goals : Array MVarId) : MetaM (Array (Array MVarId)) := do
def groupGoalsByFunction (argsPacker : ArgsPacker) (numFuncs : Nat) (goals : Array MVarId) : MetaM (Array (Array MVarId)) := do
let mut r := mkArray numFuncs #[]
for goal in goals do
let (.mdata _ (.app _ param)) ← goal.getType
| throwError "MVar does not look like like a recursive call"
let (funidx, _) ← unpackMutualArg numFuncs param
let (funidx, _) ← argsPacker.unpack param
r := r.modify funidx (·.push goal)
return r

def solveDecreasingGoals (decrTactics : Array (Option DecreasingBy)) (value : Expr) : MetaM Expr := do
def solveDecreasingGoals (argsPacker : ArgsPacker) (decrTactics : Array (Option DecreasingBy)) (value : Expr) : MetaM Expr := do
let goals ← getMVarsNoDelayed value
let goals ← assignSubsumed goals
let goalss ← groupGoalsByFunction decrTactics.size goals
let goalss ← groupGoalsByFunction argsPacker decrTactics.size goals
for goals in goalss, decrTactic? in decrTactics do
Lean.Elab.Term.TermElabM.run' do
match decrTactic? with
Expand All @@ -205,8 +205,8 @@ def solveDecreasingGoals (decrTactics : Array (Option DecreasingBy)) (value : Ex
Term.reportUnsolvedGoals remainingGoals
instantiateMVars value

def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (wfRel : Expr)
(decrTactics : Array (Option DecreasingBy)) : TermElabM Expr := do
def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (argsPacker : ArgsPacker)
(wfRel : Expr) (decrTactics : Array (Option DecreasingBy)) : TermElabM Expr := do
let type ← instantiateForall preDef.type prefixArgs
let (wfFix, varName) ← forallBoundedTelescope type (some 1) fun x type => do
let x := x[0]!
Expand All @@ -229,7 +229,7 @@ def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (wfRel : Expr)
let val := preDef.value.beta (prefixArgs.push x)
let val ← processSumCasesOn x F val fun x F val => do
processPSigmaCasesOn x F val (replaceRecApps preDef.declName prefixArgs.size)
let val ← solveDecreasingGoals decrTactics val
let val ← solveDecreasingGoals argsPacker decrTactics val
mkLambdaFVars prefixArgs (mkApp wfFix (← mkLambdaFVars #[x, F] val))

end Lean.Elab.WF
44 changes: 26 additions & 18 deletions src/Lean/Elab/PreDefinition/WF/GuessLex.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import Lean.Meta.Match.MatcherApp.Transform
import Lean.Meta.Tactic.Cleanup
import Lean.Meta.Tactic.Refl
import Lean.Meta.Tactic.TryThis
import Lean.Meta.ArgsPacker
import Lean.Elab.Quotation
import Lean.Elab.RecAppSyntax
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Structural.Basic
import Lean.Elab.PreDefinition.WF.TerminationHint
import Lean.Elab.PreDefinition.WF.PackMutual
import Lean.Data.Array


Expand Down Expand Up @@ -84,11 +84,11 @@ def originalVarNames (preDef : PreDefinition) : MetaM (Array Name) := do
lambdaTelescope preDef.value fun xs _ => xs.mapM (·.fvarId!.getUserName)

/--
Given the original paramter names from `originalVarNames`, remove the fixed prefix and find
Given the original parameter names from `originalVarNames`, find
good variable names to be used when talking about termination arguments:
Use user-given parameter names if present; use x1...xn otherwise.

The names ought to accessible (no macro scopes) and new names fresh wrt to the current environment,
The names ought to accessible (no macro scopes) and fresh wrt to the current environment,
so that with `showInferredTerminationBy` we can print them to the user reliably.
We do that by appending `'` as needed.

Expand All @@ -97,8 +97,7 @@ shadow each other, and the guessed relation refers to the wrong one. In that
case, the user gets to keep both pieces (and may have to rename variables).
-/
partial
def naryVarNames (fixedPrefixSize : Nat) (xs : Array Name) : MetaM (Array Name) := do
let xs := xs.extract fixedPrefixSize xs.size
def naryVarNames (xs : Array Name) : MetaM (Array Name) := do
let mut ns : Array Name := #[]
for h : i in [:xs.size] do
let n := xs[i]
Expand Down Expand Up @@ -264,8 +263,8 @@ def filterSubsumed (rcs : Array RecCallWithContext ) : Array RecCallWithContext
/-- Traverse a unary PreDefinition, and returns a `WithRecCall` closure for each recursive
call site.
-/
def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat) (arities : Array Nat)
: MetaM (Array RecCallWithContext) := withoutModifyingState do
def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
(argsPacker : ArgsPacker) : MetaM (Array RecCallWithContext) := withoutModifyingState do
addAsAxiom unaryPreDef
lambdaTelescope unaryPreDef.value fun xs body => do
unless xs.size == fixedPrefixSize + 1 do
Expand All @@ -277,8 +276,8 @@ def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat) (ariti
throwError "Insufficient arguments in recursive call"
let arg := args[fixedPrefixSize]!
trace[Elab.definition.wf] "collectRecCalls: {unaryPreDef.declName} ({param}) → {unaryPreDef.declName} ({arg})"
let (caller, params) ← unpackArg arities param
let (callee, args) ← unpackArg arities arg
let (caller, params) ← argsPacker.unpack param
let (callee, args) ← argsPacker.unpack arg
RecCallWithContext.create (← getRef) caller params callee args

/-- A `GuessLexRel` described how a recursive call affects a measure; whether it
Expand Down Expand Up @@ -738,12 +737,14 @@ Try to find a lexicographic ordering of the arguments for which the recursive de
terminates. See the module doc string for a high-level overview.
-/
def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
(fixedPrefixSize : Nat) :
(fixedPrefixSize : Nat) (argsPacker : ArgsPacker) :
MetaM TerminationWF := do
let extraParamss := preDefs.map (·.termination.extraParams)
let arities := argsPacker.varNamess.map (·.size)
let userVarNamess ← argsPacker.varNamess.mapM (naryVarNames ·)
-- with fixed prefix, used to qualify the measure in buildTermWf.
let originalVarNamess ← preDefs.mapM originalVarNames
let varNamess ← originalVarNamess.mapM (naryVarNames fixedPrefixSize ·)
let arities := varNamess.map (·.size)
trace[Elab.definition.wf] "varNames is: {varNamess}"
trace[Elab.definition.wf] "varNames is: {userVarNamess}"

let forbiddenArgs ← preDefs.mapM (getForbiddenByTrivialSizeOf fixedPrefixSize)
let needsNoSizeOf ←
Expand All @@ -758,23 +759,30 @@ def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)

-- If there is only one plausible measure, use that
if let #[solution] := measures then
let wf ← buildTermWF originalVarNamess varNamess needsNoSizeOf #[solution]
let wf ← buildTermWF originalVarNamess userVarNamess needsNoSizeOf #[solution]
reportWF preDefs wf
return wf

-- Collect all recursive calls and extract their context
let recCalls ← collectRecCalls unaryPreDef fixedPrefixSize arities
let recCalls ← collectRecCalls unaryPreDef fixedPrefixSize argsPacker
let recCalls := filterSubsumed recCalls
let rcs ← recCalls.mapM (RecCallCache.mk (preDefs.map (·.termination.decreasingBy?)) ·)
let callMatrix := rcs.map (inspectCall ·)

match ← liftMetaM <| solve measures callMatrix with
| .some solution => do
let wf ← buildTermWF originalVarNamess varNamess needsNoSizeOf solution
reportWF preDefs wf
let wf ← buildTermWF originalVarNamess userVarNamess needsNoSizeOf solution

let wf' := trimTermWF extraParamss wf
for preDef in preDefs, term in wf' do
if showInferredTerminationBy.get (← getOptions) then
logInfoAt preDef.ref m!"Inferred termination argument:\n{← term.unexpand}"
if let some ref := preDef.termination.terminationBy?? then
Tactic.TryThis.addSuggestion ref (← term.unexpand)

return wf
| .none =>
let explanation ← explainFailure (preDefs.map (·.declName)) varNamess rcs
let explanation ← explainFailure (preDefs.map (·.declName)) userVarNamess rcs
Lean.throwError <| "Could not find a decreasing measure.\n" ++
explanation ++ "\n" ++
"Please use `termination_by` to specify a decreasing measure."
62 changes: 35 additions & 27 deletions src/Lean/Elab/PreDefinition/WF/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Authors: Leonardo de Moura
prelude
import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.WF.TerminationHint
import Lean.Elab.PreDefinition.WF.PackDomain
import Lean.Elab.PreDefinition.WF.PackMutual
import Lean.Elab.PreDefinition.WF.Preprocess
import Lean.Elab.PreDefinition.WF.Rel
Expand All @@ -19,29 +18,15 @@ namespace Lean.Elab
open WF
open Meta

private partial def addNonRecPreDefs (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) (fixedPrefixSize : Nat) : TermElabM Unit := do
private partial def addNonRecPreDefs (fixedPrefixSize : Nat) (argsPacker : ArgsPacker) (preDefs : Array PreDefinition) (preDefNonRec : PreDefinition) : TermElabM Unit := do
let us := preDefNonRec.levelParams.map mkLevelParam
let all := preDefs.toList.map (·.declName)
for fidx in [:preDefs.size] do
let preDef := preDefs[fidx]!
let value ← lambdaTelescope preDef.value fun xs _ => do
let packedArgs : Array Expr := xs[fixedPrefixSize:]
let mkProd (type : Expr) : MetaM Expr := do
mkUnaryArg type packedArgs
let rec mkSum (i : Nat) (type : Expr) : MetaM Expr := do
if i == preDefs.size - 1 then
mkProd type
else
(← whnfD type).withApp fun f args => do
assert! args.size == 2
if i == fidx then
return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! (← mkProd args[0]!)
else
let r ← mkSum (i+1) args[1]!
return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r
let Expr.forallE _ domain _ _ := (← instantiateForall preDefNonRec.type xs[:fixedPrefixSize]) | unreachable!
let arg ← mkSum 0 domain
mkLambdaFVars xs (mkApp (mkAppN (mkConst preDefNonRec.declName us) xs[:fixedPrefixSize]) arg)
let value ← forallBoundedTelescope preDef.type (some fixedPrefixSize) fun xs _ => do
let value := mkAppN (mkConst preDefNonRec.declName us) xs
let value ← argsPacker.curryProj value fidx
mkLambdaFVars xs value
trace[Elab.definition.wf] "{preDef.declName} := {value}"
addNonRec { preDef with value } (applyAttrAfterCompilation := false) (all := all)

Expand Down Expand Up @@ -81,23 +66,44 @@ private def isOnlyOneUnaryDef (preDefs : Array PreDefinition) (fixedPrefixSize :
else
return false

/--
Collect the names of the varying variables (after the fixed prefix); this also determines the
arity for the well-founded translations, and is turned into an `ArgsPacker`.
We use the term to determine the arity, but take the name from the type, for better names in the
```
fun : (n : Nat) → Nat | 0 => 0 | n+1 => fun n
```
idiom.
-/
def varyingVarNames (fixedPrefixSize : Nat) (preDef : PreDefinition) : MetaM (Array Name) := do
-- We take the arity from the term, but the names from the types
let arity ← lambdaTelescope preDef.value fun xs _ => return xs.size
assert! fixedPrefixSize ≤ arity
if arity = fixedPrefixSize then
throwError "well-founded recursion cannot be used, '{preDef.declName}' does not take any (non-fixed) arguments"
forallBoundedTelescope preDef.type arity fun xs _ => do
assert! xs.size = arity
let xs : Array Expr := xs[fixedPrefixSize:]
xs.mapM (·.fvarId!.getUserName)

def wfRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let preDefs ← preDefs.mapM fun preDef =>
return { preDef with value := (← preprocess preDef.value) }
let (unaryPreDef, fixedPrefixSize) ← withoutModifyingEnv do
let (fixedPrefixSize, argsPacker, unaryPreDef) ← withoutModifyingEnv do
for preDef in preDefs do
addAsAxiom preDef
let fixedPrefixSize ← getFixedPrefix preDefs
trace[Elab.definition.wf] "fixed prefix: {fixedPrefixSize}"
let varNamess ← preDefs.mapM (varyingVarNames fixedPrefixSize ·)
let argsPacker := { varNamess }
let preDefsDIte ← preDefs.mapM fun preDef => return { preDef with value := (← iteToDIte preDef.value) }
let unaryPreDefs ← packDomain fixedPrefixSize preDefsDIte
return (← packMutual fixedPrefixSize preDefs unaryPreDefs, fixedPrefixSize)
return (fixedPrefixSize, argsPacker, ← packMutual fixedPrefixSize argsPacker preDefsDIte)

let wf ← do
let (preDefsWith, preDefsWithout) := preDefs.partition (·.termination.terminationBy?.isSome)
if preDefsWith.isEmpty then
-- No termination_by anywhere, so guess one
guessLex preDefs unaryPreDef fixedPrefixSize
guessLex preDefs unaryPreDef fixedPrefixSize argsPacker
else if preDefsWithout.isEmpty then
pure <| preDefsWith.map (·.termination.terminationBy?.get!)
else
Expand All @@ -109,12 +115,14 @@ def wfRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do

let preDefNonRec ← forallBoundedTelescope unaryPreDef.type fixedPrefixSize fun prefixArgs type => do
let type ← whnfForall type
unless type.isForall do
throwError "wfRecursion: expected unary function type: {type}"
let packedArgType := type.bindingDomain!
elabWFRel preDefs unaryPreDef.declName fixedPrefixSize packedArgType wf fun wfRel => do
trace[Elab.definition.wf] "wfRel: {wfRel}"
let (value, envNew) ← withoutModifyingEnv' do
addAsAxiom unaryPreDef
let value ← mkFix unaryPreDef prefixArgs wfRel (preDefs.map (·.termination.decreasingBy?))
let value ← mkFix unaryPreDef prefixArgs argsPacker wfRel (preDefs.map (·.termination.decreasingBy?))
eraseRecAppSyntaxExpr value
/- `mkFix` invokes `decreasing_tactic` which may add auxiliary theorems to the environment. -/
let value ← unfoldDeclsFrom envNew value
Expand All @@ -126,12 +134,12 @@ def wfRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
else
withEnableInfoTree false do
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
addNonRecPreDefs preDefs preDefNonRec fixedPrefixSize
addNonRecPreDefs fixedPrefixSize argsPacker preDefs preDefNonRec
-- We create the `_unsafe_rec` before we abstract nested proofs.
-- Reason: the nested proofs may be referring to the _unsafe_rec.
addAndCompilePartialRec preDefs
let preDefs ← preDefs.mapM (abstractNestedProofs ·)
registerEqnsInfo preDefs preDefNonRec.declName fixedPrefixSize
registerEqnsInfo preDefs preDefNonRec.declName fixedPrefixSize argsPacker
for preDef in preDefs do
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation

Expand Down
Loading
Loading