Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: lemma selection improvements to to rw? and lazy discriminator tree #3769

Merged
merged 6 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 95 additions & 56 deletions src/Lean/Meta/LazyDiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,35 @@ private def newTrie [Monad m] [MonadState (Array (Trie α)) m] (e : LazyEntry α
private def addLazyEntryToTrie (i:TrieIndex) (e : LazyEntry α) : MatchM α Unit :=
modify (·.modify i (·.pushPending e))

private def evalLazyEntry (config : WhnfCoreConfig)
(p : Array α × TrieIndex × HashMap Key TrieIndex)
(entry : LazyEntry α)
: MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let (values, starIdx, children) := p
let (todo, lctx, v) := entry
if todo.isEmpty then
let values := values.push v
pure (values, starIdx, children)
else
let e := todo.back
let todo := todo.pop
let (k, todo) ← withLCtx lctx.1 lctx.2 $ pushArgs false todo e config
if k == .star then
if starIdx = 0 then
let starIdx ← newTrie (todo, lctx, v)
pure (values, starIdx, children)
else
addLazyEntryToTrie starIdx (todo, lctx, v)
pure (values, starIdx, children)
else
match children.find? k with
| none =>
let children := children.insert k (← newTrie (todo, lctx, v))
pure (values, starIdx, children)
| some idx =>
addLazyEntryToTrie idx (todo, lctx, v)
pure (values, starIdx, children)

/--
This evaluates all lazy entries in a trie and updates `values`, `starIdx`, and `children`
accordingly.
Expand All @@ -453,34 +482,10 @@ private partial def evalLazyEntries (config : WhnfCoreConfig)
(values : Array α) (starIdx : TrieIndex) (children : HashMap Key TrieIndex)
(entries : Array (LazyEntry α)) :
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
let rec iter values starIdx children (i : Nat) : MatchM α _ := do
if p : i < entries.size then
let (todo, lctx, v) := entries[i]
if todo.isEmpty then
let values := values.push v
iter values starIdx children (i+1)
else
let e := todo.back
let todo := todo.pop
let (k, todo) ← withLCtx lctx.1 lctx.2 $ pushArgs false todo e config
if k == .star then
if starIdx = 0 then
let starIdx ← newTrie (todo, lctx, v)
iter values starIdx children (i+1)
else
addLazyEntryToTrie starIdx (todo, lctx, v)
iter values starIdx children (i+1)
else
match children.find? k with
| none =>
let children := children.insert k (← newTrie (todo, lctx, v))
iter values starIdx children (i+1)
| some idx =>
addLazyEntryToTrie idx (todo, lctx, v)
iter values starIdx children (i+1)
else
pure (values, starIdx, children)
iter values starIdx children 0
let mut values := values
let mut starIdx := starIdx
let mut children := children
entries.foldlM (init := (values, starIdx, children)) (evalLazyEntry config)

private def evalNode (c : TrieIndex) :
MatchM α (Array α × TrieIndex × HashMap Key TrieIndex) := do
Expand Down Expand Up @@ -590,7 +595,7 @@ private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieInde
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : MatchResult α) : MatchM α (MatchResult α) :=
if star != 0 then
getMatchLoop todo score star result
getMatchLoop todo (score + 1) star result
else
return result
let visitNonStar (k : Key) (args : Array Expr) (result : MatchResult α) :=
Expand All @@ -617,26 +622,26 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
pure <| {}
| some idx => do
let (vs, _) ← evalNode idx
pure <| ({} : MatchResult α).push 0 vs
pure <| ({} : MatchResult α).push (score := 1) vs

private def getMatchRoot (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(result : MatchResult α) : MatchM α (MatchResult α) :=
match r.find? k with
| none => pure result
| some c => getMatchLoop args 1 c result
| some c => getMatchLoop args (score := 1) c result

/--
Find values that match `e` in `root`.
-/
private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
MatchM α (MatchResult α) := do
let result ← getStarResult root
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (←read)
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
match k with
| .star => return result
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
| .arrow =>
getMatchRoot root k args (←getMatchRoot root .other #[] result)
getMatchRoot root k args (← getMatchRoot root .other #[] result)
| _ =>
getMatchRoot root k args result

Expand Down Expand Up @@ -756,7 +761,28 @@ structure Cache where

def Cache.empty (ngen : NameGenerator) : Cache := { ngen := ngen, core := {}, meta := {} }

def matchPrefix (s : String) (pre : String) :=
s.startsWith pre && (s |>.drop pre.length |>.all Char.isDigit)

def isInternalDetail : Name → Bool
| .str p s =>
s.startsWith "_"
|| matchPrefix s "eq_"
|| matchPrefix s "match_"
|| matchPrefix s "proof_"
|| p.isInternalOrNum
| .num _ _ => true
| p => p.isInternalOrNum

def blacklistInsertion (env : Environment) (declName : Name) : Bool :=
!allowCompletion env declName
|| declName == ``sorryAx
|| isInternalDetail declName
|| (declName matches .str _ "inj")
|| (declName matches .str _ "noConfusionType")

private def addConstImportData
(cctx : Core.Context)
(env : Environment)
(modName : Name)
(d : ImportData)
Expand All @@ -765,16 +791,12 @@ private def addConstImportData
(act : Name → ConstantInfo → MetaM (Array (InitEntry α)))
(name : Name) (constInfo : ConstantInfo) : BaseIO (PreDiscrTree α) := do
if constInfo.isUnsafe then return tree
if !allowCompletion env name then return tree
if blacklistInsertion env name then return tree
let { ngen, core := core_cache, meta := meta_cache } ← cacheRef.get
let mstate : Meta.State := { cache := meta_cache }
cacheRef.set (Cache.empty ngen)
let ctx : Meta.Context := { config := { transparency := .reducible } }
let cm := (act name constInfo).run ctx mstate
let cctx : Core.Context := {
fileName := default,
fileMap := default
}
let cstate : Core.State := {env, cache := core_cache, ngen}
match ←(cm.run cctx cstate).toBaseIO with
| .ok ((a, ms), cs) =>
Expand Down Expand Up @@ -818,7 +840,9 @@ private def toFlat (d : ImportData) (tree : PreDiscrTree α) :
let de ← d.errors.swap #[]
pure ⟨tree, de⟩

private partial def loadImportedModule (env : Environment)
private partial def loadImportedModule
(cctx : Core.Context)
(env : Environment)
(act : Name → ConstantInfo → MetaM (Array (InitEntry α)))
(d : ImportData)
(cacheRef : IO.Ref Cache)
Expand All @@ -829,12 +853,12 @@ private partial def loadImportedModule (env : Environment)
if h : i < mdata.constNames.size then
let name := mdata.constNames[i]
let constInfo := mdata.constants[i]!
let tree ← addConstImportData env mname d cacheRef tree act name constInfo
loadImportedModule env act d cacheRef tree mname mdata (i+1)
let tree ← addConstImportData cctx env mname d cacheRef tree act name constInfo
loadImportedModule cctx env act d cacheRef tree mname mdata (i+1)
else
pure tree

private def createImportedEnvironmentSeq (ngen : NameGenerator) (env : Environment)
private def createImportedEnvironmentSeq (cctx : Core.Context) (ngen : NameGenerator) (env : Environment)
(act : Name → ConstantInfo → MetaM (Array (InitEntry α)))
(start stop : Nat) : BaseIO (InitResults α) := do
let cacheRef ← IO.mkRef (Cache.empty ngen)
Expand All @@ -843,7 +867,7 @@ private def createImportedEnvironmentSeq (ngen : NameGenerator) (env : Environme
if start < stop then
let mname := env.header.moduleNames[start]!
let mdata := env.header.moduleData[start]!
let tree ← loadImportedModule env act d cacheRef tree mname mdata
let tree ← loadImportedModule cctx env act d cacheRef tree mname mdata
go d cacheRef tree (start+1) stop
else
toFlat d tree
Expand All @@ -860,6 +884,7 @@ def getChildNgen [Monad M] [MonadNameGenerator M] : M NameGenerator := do
pure cngen

def createLocalPreDiscrTree
(cctx : Core.Context)
(ngen : NameGenerator)
(env : Environment)
(d : ImportData)
Expand All @@ -868,18 +893,22 @@ def createLocalPreDiscrTree
let modName := env.header.mainModule
let cacheRef ← IO.mkRef (Cache.empty ngen)
let act (t : PreDiscrTree α) (n : Name) (c : ConstantInfo) : BaseIO (PreDiscrTree α) :=
addConstImportData env modName d cacheRef t act n c
addConstImportData cctx env modName d cacheRef t act n c
let r ← (env.constants.map₂.foldlM (init := {}) act : BaseIO (PreDiscrTree α))
pure r

def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do
keys.foldlM (init := t) (·.dropKey ·)

def logImportFailure [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] (f : ImportFailure) : m Unit :=
logError m!"Processing failure with {f.const} in {f.module}:\n {f.exception.toMessageData}"

/-- Create a discriminator tree for imported environment. -/
def createImportedDiscrTree (ngen : NameGenerator) (env : Environment)
def createImportedDiscrTree [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] [MonadLiftT BaseIO m]
(cctx : Core.Context) (ngen : NameGenerator) (env : Environment)
(act : Name → ConstantInfo → MetaM (Array (InitEntry α)))
(constantsPerTask : Nat := 1000) :
EIO Exception (LazyDiscrTree α) := do
m (LazyDiscrTree α) := do
let n := env.header.moduleData.size
let rec
/-- Allocate constants to tasks according to `constantsPerTask`. -/
Expand All @@ -889,37 +918,48 @@ def createImportedDiscrTree (ngen : NameGenerator) (env : Environment)
let cnt := cnt + mdata.constants.size
if cnt > constantsPerTask then
let (childNGen, ngen) := ngen.mkChild
let t ← createImportedEnvironmentSeq childNGen env act start (idx+1) |>.asTask
let t ← liftM <| createImportedEnvironmentSeq cctx childNGen env act start (idx+1) |>.asTask
go ngen (tasks.push t) (idx+1) 0 (idx+1)
else
go ngen tasks start cnt (idx+1)
else
if start < n then
let (childNGen, _) := ngen.mkChild
tasks.push <$> (createImportedEnvironmentSeq childNGen env act start n).asTask
let t ← (createImportedEnvironmentSeq cctx childNGen env act start n).asTask
pure (tasks.push t)
else
pure tasks
termination_by env.header.moduleData.size - idx
let tasks ← go ngen #[] 0 0 0
let r := combineGet default tasks
if p : r.errors.size > 0 then
throw r.errors[0].exception
r.errors.forM logImportFailure
pure <| r.tree.toLazy

/-- Creates the core context used for initializing a tree using the current context. -/
private def createTreeCtx (ctx : Core.Context) : Core.Context := {
fileName := ctx.fileName,
fileMap := ctx.fileMap,
options := ctx.options,
maxRecDepth := ctx.maxRecDepth,
maxHeartbeats := 0,
ref := ctx.ref,
}

def findImportMatches
(ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(ty : Expr) : MetaM (MatchResult α) := do
let cctx ← (read : CoreM Core.Context)
let ngen ← getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
let dummy : IO.Ref (Option (LazyDiscrTree α)) ← IO.mkRef none
let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv)
let importTree ← (←ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do
let t ← createImportedDiscrTree cNGen (←getEnv) addEntry
let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry
(constantsPerTask := constantsPerTask)
dropKeys t droppedKeys
let (importCandidates, importTree) ← importTree.getMatch ty
Expand All @@ -943,10 +983,9 @@ def createModuleDiscrTree
let env ← getEnv
let ngen ← getChildNgen
let d ← ImportData.new
let t ← createLocalPreDiscrTree ngen env d entriesForConst
let errors ← d.errors.get
if p : errors.size > 0 then
throw errors[0].exception
let ctx ← read
let t ← createLocalPreDiscrTree ctx ngen env d entriesForConst
(← d.errors.get).forM logImportFailure
pure <| t.toLazy

/--
Expand All @@ -966,7 +1005,7 @@ Returns candidates from this module in this module that match the expression.
this module's definitions.
-/
def findModuleMatches (moduleRef : ModuleDiscrTreeRef α) (ty : Expr) : MetaM (MatchResult α) := do
profileitM Exception "lazy discriminator local search" (←getOptions) $ do
profileitM Exception "lazy discriminator local search" (← getOptions) $ do
let discrTree ← moduleRef.ref.get
let (localCandidates, localTree) ← discrTree.getMatch ty
moduleRef.ref.set localTree
Expand Down
Loading
Loading