Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: periodically refresh semantic tokens #3691

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/Lean/Data/Lsp/Ipc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def readRequestAs (expectedMethod : String) (α) [FromJson α] : IpcM (Request
(←stdout).readLspRequestAs expectedMethod α

/--
Reads response, discarding notifications in between. This function is meant
purely for testing where we use `collectDiagnostics` explicitly if we do care
about such notifications. -/
Reads response, discarding notifications and server-to-client requests in between.
This function is meant purely for testing where we use `collectDiagnostics` explicitly
if we do care about such notifications.
-/
partial def readResponseAs (expectedID : RequestID) (α) [FromJson α] :
IpcM (Response α) := do
let m ← (←stdout).readLspMessage
Expand All @@ -79,7 +80,8 @@ partial def readResponseAs (expectedID : RequestID) (α) [FromJson α] :
else
throw $ userError s!"Expected id {expectedID}, got id {id}"
| .notification .. => readResponseAs expectedID α
| _ => throw $ userError s!"Expected JSON-RPC response, got: '{(toJson m).compress}'"
| .request .. => readResponseAs expectedID α
| .responseError .. => throw $ userError s!"Expected JSON-RPC response, got: '{(toJson m).compress}'"

def waitForExit : IpcM UInt32 := do
(←read).wait
Expand Down
48 changes: 36 additions & 12 deletions src/Lean/Server/FileWorker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,22 @@ open JsonRpc
structure WorkerContext where
/-- Synchronized output channel for LSP messages. Notifications for outdated versions are
discarded on read. -/
chanOut : IO.Channel JsonRpc.Message
chanOut : IO.Channel JsonRpc.Message
/--
Latest document version received by the client, used for filtering out notifications from
previous versions.
-/
maxDocVersionRef : IO.Ref Int
hLog : FS.Stream
initParams : InitializeParams
processor : Parser.InputContext → BaseIO Lean.Language.Lean.InitialSnapshot
clientHasWidgets : Bool
maxDocVersionRef : IO.Ref Int
freshRequestIdRef : IO.Ref Int
hLog : FS.Stream
initParams : InitializeParams
processor : Parser.InputContext → BaseIO Lean.Language.Lean.InitialSnapshot
clientHasWidgets : Bool
/--
Options defined on the worker cmdline (i.e. not including options from `setup-file`), used for
context-free tasks such as editing delay.
-/
cmdlineOpts : Options
cmdlineOpts : Options

/-! # Asynchronous snapshot elaboration -/

Expand Down Expand Up @@ -222,7 +223,6 @@ structure WorkerState where
/-- A map of RPC session IDs. We allow asynchronous elab tasks and request handlers
to modify sessions. A single `Ref` ensures atomic transactions. -/
rpcSessions : RBMap UInt64 (IO.Ref RpcSession) compare
freshRequestID : Nat

abbrev WorkerM := ReaderT WorkerContext <| StateRefT WorkerState IO

Expand Down Expand Up @@ -288,6 +288,7 @@ section Initialization
mainModuleName ← moduleNameOfFileName path none
catch _ => pure ()
let maxDocVersionRef ← IO.mkRef 0
let freshRequestIdRef ← IO.mkRef 0
let chanOut ← mkLspOutputChannel maxDocVersionRef
let srcSearchPathPromise ← IO.Promise.new

Expand All @@ -301,6 +302,7 @@ section Initialization
processor
clientHasWidgets
maxDocVersionRef
freshRequestIdRef
cmdlineOpts := opts
}
let doc : EditableDocumentCore := {
Expand All @@ -316,7 +318,6 @@ section Initialization
pendingRequests := RBMap.empty
rpcSessions := RBMap.empty
importCachingTask? := none
freshRequestID := 0
})
where
/-- Creates an LSP message output channel along with a reader that sends out read messages on
Expand Down Expand Up @@ -345,6 +346,18 @@ section Initialization
return chanOut
end Initialization

section ServerRequests
def sendServerRequest [ToJson α]
(ctx : WorkerContext)
(method : String)
(param : α)
: IO Unit := do
let freshRequestId ← ctx.freshRequestIdRef.modifyGet fun freshRequestId =>
(freshRequestId, freshRequestId + 1)
let r : JsonRpc.Request α := ⟨freshRequestId, method, param⟩
ctx.chanOut.send r
end ServerRequests

section Updates
def updatePendingRequests (map : PendingRequestMap → PendingRequestMap) : WorkerM Unit := do
modify fun st => { st with pendingRequests := map st.pendingRequests }
Expand Down Expand Up @@ -541,8 +554,8 @@ section MessageHandling
ctx.chanOut.send <| e.toLspResponseError id
queueRequest id t

def handleResponse (_ : RequestID) (result : Json) : WorkerM Unit :=
throwServerError s!"Unknown response kind: {result}"
def handleResponse (_ : RequestID) (_ : Json) : WorkerM Unit :=
return -- The only response that we currently expect here is always empty

end MessageHandling

Expand Down Expand Up @@ -589,6 +602,13 @@ section MainLoop
| _ => throwServerError "Got invalid JSON-RPC message"
end MainLoop

def runRefreshTask : WorkerM (Task (Except IO.Error Unit)) := do
let ctx ← read
IO.asTask do
while ! (←IO.checkCanceled) do
sendServerRequest ctx "workspace/semanticTokens/refresh" (none : Option Nat)
IO.sleep 2000

def initAndRunWorker (i o e : FS.Stream) (opts : Options) : IO UInt32 := do
let i ← maybeTee "fwIn.txt" false i
let o ← maybeTee "fwOut.txt" true o
Expand All @@ -605,7 +625,11 @@ def initAndRunWorker (i o e : FS.Stream) (opts : Options) : IO UInt32 := do
let _ ← IO.setStderr e
try
let (ctx, st) ← initializeWorker meta o e initParams.param opts
let _ ← StateRefT'.run (s := st) <| ReaderT.run (r := ctx) (mainLoop i)
let _ ← StateRefT'.run (s := st) <| ReaderT.run (r := ctx) do
let refreshTask ← runRefreshTask
let exitCode ← mainLoop i
IO.cancel refreshTask
pure exitCode
return (0 : UInt32)
catch err =>
IO.eprintln err
Expand Down
35 changes: 22 additions & 13 deletions src/Lean/Server/FileWorker/RequestHandling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,10 @@ def noHighlightKinds : Array SyntaxNodeKind := #[
``Lean.Parser.Command.moduleDoc]

structure SemanticTokensContext where
beginPos : String.Pos
endPos : String.Pos
text : FileMap
snap : Snapshot
beginPos : String.Pos
endPos? : Option String.Pos
text : FileMap
snap : Snapshot

structure SemanticTokensState where
data : Array Nat
Expand All @@ -447,20 +447,29 @@ def keywordSemanticTokenMap : RBMap String SemanticTokenType compare :=
|>.insert "stop" .leanSorryLike
|>.insert "#exit" .leanSorryLike

partial def handleSemanticTokens (beginPos endPos : String.Pos)
partial def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
: RequestM (RequestTask SemanticTokens) := do
let doc ← readDoc
let text := doc.meta.text
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) =>
match endPos? with
| none =>
-- Only grabs the finished prefix so that we do not need to wait for elaboration to complete
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _) ← doc.cmdSnaps.getFinishedPrefix
asTask <| run doc snaps
| some endPos =>
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => run doc snaps
where
run doc snaps : RequestM SemanticTokens :=
StateT.run' (s := { data := #[], lastLspPos := ⟨0, 0⟩ : SemanticTokensState }) do
for s in snaps do
if s.endPos <= beginPos then
continue
ReaderT.run (r := SemanticTokensContext.mk beginPos endPos text s) <|
ReaderT.run (r := SemanticTokensContext.mk beginPos endPos? doc.meta.text s) <|
go s.stx
return { data := (← get).data }
where
go (stx : Syntax) := do
match stx with
| `($e.$id:ident) => go e; addToken id SemanticTokenType.property
Expand Down Expand Up @@ -506,9 +515,9 @@ where
(val.length > 1 && val.front == '#' && (val.get ⟨1⟩).isAlpha) then
addToken stx (keywordSemanticTokenMap.findD val .keyword)
addToken stx type := do
let ⟨beginPos, endPos, text, _⟩ ← read
let ⟨beginPos, endPos?, text, _⟩ ← read
if let (some pos, some tailPos) := (stx.getPos?, stx.getTailPos?) then
if beginPos <= pos && pos < endPos then
if beginPos <= pos && endPos?.all (pos < ·) then
let lspPos := (← get).lastLspPos
let lspPos' := text.utf8PosToLspPos pos
let deltaLine := lspPos'.line - lspPos.line
Expand All @@ -523,7 +532,7 @@ where

def handleSemanticTokensFull (_ : SemanticTokensParams)
: RequestM (RequestTask SemanticTokens) := do
handleSemanticTokens 0 ⟨1 <<< 31⟩
handleSemanticTokens 0 none

def handleSemanticTokensRange (p : SemanticTokensRangeParams)
: RequestM (RequestTask SemanticTokens) := do
Expand Down
1 change: 1 addition & 0 deletions tests/lean/interactive/run.lean
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ partial def main (args : List String) : IO Unit := do
assert! id == requestNo
return r
| Message.notification .. => readFirstResponse
| Message.request .. => readFirstResponse
| msg => throw <| IO.userError s!"unexpected message {toJson msg}"
let resp ← readFirstResponse
IO.eprintln resp
Expand Down
Loading