From 5b3a8daae210dfc663c2c21ce2e4e9ed3300867c Mon Sep 17 00:00:00 2001 From: Denis Lantsman Date: Mon, 16 Dec 2024 07:26:08 -0800 Subject: [PATCH] refactor everything to be in one model, use thunks for sendmessage, reference toolManager from parts --- rplugin/node/magenta/src/anthropic.ts | 12 +- rplugin/node/magenta/src/chat/chat.ts | 182 +++++++++++++----- rplugin/node/magenta/src/chat/message.ts | 36 ++-- rplugin/node/magenta/src/chat/part.ts | 26 ++- .../node/magenta/src/debug/call-anthropic.ts | 7 +- rplugin/node/magenta/src/magenta.ts | 92 +-------- rplugin/node/magenta/src/tea/tea.ts | 8 + rplugin/node/magenta/src/tools/getFile.ts | 3 +- rplugin/node/magenta/src/tools/insert.ts | 47 ++++- rplugin/node/magenta/src/tools/toolManager.ts | 33 ++-- rplugin/node/magenta/test/preamble.ts | 3 +- 11 files changed, 251 insertions(+), 198 deletions(-) diff --git a/rplugin/node/magenta/src/anthropic.ts b/rplugin/node/magenta/src/anthropic.ts index 51edb66..e888b5c 100644 --- a/rplugin/node/magenta/src/anthropic.ts +++ b/rplugin/node/magenta/src/anthropic.ts @@ -2,7 +2,7 @@ import Anthropic from "@anthropic-ai/sdk"; import { context } from "./context.ts"; import { TOOL_SPECS, ToolRequest } from "./tools/toolManager.ts"; -export class AnthropicClient { +class AnthropicClient { private client: Anthropic; constructor() { @@ -71,3 +71,13 @@ export class AnthropicClient { return toolRequests; } } + +let client: AnthropicClient | undefined; + +// lazy load so we have a chance to init context before constructing the class +export function getClient() { + if (!client) { + client = new AnthropicClient(); + } + return client; +} diff --git a/rplugin/node/magenta/src/chat/chat.ts b/rplugin/node/magenta/src/chat/chat.ts index 5904976..e6984ab 100644 --- a/rplugin/node/magenta/src/chat/chat.ts +++ b/rplugin/node/magenta/src/chat/chat.ts @@ -8,9 +8,17 @@ import { view as messageView, } from "./message.ts"; import { ToolModel } from "../tools/toolManager.ts"; -import { Dispatch, Update } from "../tea/tea.ts"; +import { + Dispatch, + parallelThunks, + Thunk, + Update, + wrapThunk, +} from "../tea/tea.ts"; import { d, View, withBindings } from "../tea/view.ts"; import { context } from "../context.ts"; +import * as ToolManager from "../tools/toolManager.ts"; +import { getClient } from "../anthropic.ts"; export type Role = "user" | "assistant"; @@ -25,8 +33,16 @@ export type ChatState = state: "awaiting-tool-use"; }; +export function initModel(): Model { + return { + messages: [], + toolManager: ToolManager.initModel(), + }; +} + export type Model = { messages: Message[]; + toolManager: ToolManager.Model; }; export type Msg = @@ -45,20 +61,18 @@ export type Msg = text: string; } | { - type: "add-tool-use"; - toolModel: ToolModel; + type: "init-tool-use"; + request: ToolManager.ToolRequest; } | { - type: "add-tool-response"; - toolModel: ToolModel; - response: ToolResultBlockParam; + type: "send-message"; } | { - type: "tool-model-update"; - toolModel: ToolModel; + type: "clear"; } | { - type: "clear"; + type: "tool-manager-msg"; + msg: ToolManager.Msg; }; export const update: Update = (msg, model) => { @@ -80,9 +94,27 @@ export const update: Update = (msg, model) => { return [model]; } + case "send-message": { + const lastMessage = model.messages[model.messages.length - 1]; + if (lastMessage && lastMessage.role == "user") { + return [model, sendMessage(model)]; + } + return [model]; + } + case "message-msg": { const [nextMessage] = updateMessage(msg.msg, model.messages[msg.idx]); model.messages[msg.idx] = nextMessage; + + if (msg.msg.type == "tool-manager-msg") { + const [nextToolManager, toolManagerThunk] = ToolManager.update( + msg.msg.msg, + model.toolManager, + ); + model.toolManager = nextToolManager; + return [model, wrapThunk("tool-manager-msg", toolManagerThunk)]; + } + return [model]; } @@ -104,7 +136,13 @@ export const update: Update = (msg, model) => { return [model]; } - case "add-tool-use": { + case "init-tool-use": { + const [nextToolManager, toolManagerThunk] = ToolManager.update( + { type: "init-tool-use", request: msg.request }, + model.toolManager, + ); + model.toolManager = nextToolManager; + const lastMessage = model.messages[model.messages.length - 1]; if (lastMessage?.role !== "assistant") { model.messages.push({ @@ -112,64 +150,103 @@ export const update: Update = (msg, model) => { parts: [], }); } + const [nextMessage] = updateMessage( - { type: "add-tool-use", toolModel: msg.toolModel }, + { type: "add-tool-use", requestId: msg.request.id }, model.messages[model.messages.length - 1], ); model.messages[model.messages.length - 1] = nextMessage; - return [model]; + return [model, wrapThunk("tool-manager-msg", toolManagerThunk)]; } - case "add-tool-response": { - let lastMessage = model.messages[model.messages.length - 1]; - if (lastMessage?.role !== "user") { - lastMessage = { - role: "user", - parts: [], - }; - model.messages.push(lastMessage); - } - - const [next] = updateMessage( - { - type: "add-tool-response", - toolModel: msg.toolModel, - response: msg.response, - }, - lastMessage, + case "tool-manager-msg": { + const [nextToolManager, toolManagerThunk] = ToolManager.update( + msg.msg, + model.toolManager, ); - model.messages.splice(model.messages.length - 1, 1, next); - return [model]; - } + model.toolManager = nextToolManager; + let nextModel = model; + let thunk: Thunk | undefined = wrapThunk( + "tool-manager-msg", + toolManagerThunk, + ); + if (msg.msg.type == "tool-msg" && msg.msg.msg.msg.type == "finish") { + const toolModel = nextToolManager.toolModels[msg.msg.id]; - case "tool-model-update": { - const nextMessages: Model["messages"] = []; - for (const message of model.messages) { - const [next] = updateMessage( - { - type: "tool-model-update", - toolModel: msg.toolModel, - }, - message, - ); + const response = msg.msg.msg.msg.result; + [nextModel] = addToolResponse(model, toolModel, response); + if (toolModel.autoRespond) { + let shouldRespond = true; + for (const tool of Object.values(model.toolManager.toolModels)) { + if (tool.state.state != "done") { + shouldRespond = false; + break; + } + } - nextMessages.push(next); + if (shouldRespond) { + thunk = parallelThunks(thunk, sendMessage(model)); + } + } } - - return [ - { - ...model, - messages: nextMessages, - }, - ]; + return [nextModel, thunk]; } case "clear": { - return [{ messages: [] }]; + return [initModel()]; } } }; +function sendMessage(model: Model): Thunk { + return async function (dispatch: Dispatch) { + const messages = getMessages(model); + + const toolRequests = await getClient().sendMessage(messages, (text) => { + context.logger.trace(`stream received text ${text}`); + dispatch({ + type: "stream-response", + text, + }); + }); + + if (toolRequests.length) { + for (const request of toolRequests) { + dispatch({ + type: "init-tool-use", + request, + }); + } + } + }; +} + +function addToolResponse( + model: Model, + toolModel: ToolModel, + response: ToolResultBlockParam, +): [Model] { + let lastMessage = model.messages[model.messages.length - 1]; + if (lastMessage?.role !== "user") { + lastMessage = { + role: "user", + parts: [], + }; + model.messages.push(lastMessage); + } + + const [next] = updateMessage( + { + type: "add-tool-response", + requestId: toolModel.request.id, + response, + }, + lastMessage, + ); + model.messages.splice(model.messages.length - 1, 1, next); + return [model]; +} + export const view: View<{ model: Model; dispatch: Dispatch }> = ({ model, dispatch, @@ -179,6 +256,7 @@ export const view: View<{ model: Model; dispatch: Dispatch }> = ({ (m, idx) => d`${messageView({ model: m, + toolManager: model.toolManager, dispatch: (msg) => { dispatch({ type: "message-msg", msg, idx }); }, @@ -191,6 +269,6 @@ export const view: View<{ model: Model; dispatch: Dispatch }> = ({ export function getMessages(model: Model): Anthropic.MessageParam[] { return model.messages.map((msg) => ({ role: msg.role, - content: msg.parts.map(toMessageParam), + content: msg.parts.map((p) => toMessageParam(p, model.toolManager)), })); } diff --git a/rplugin/node/magenta/src/chat/message.ts b/rplugin/node/magenta/src/chat/message.ts index 1787a51..b6c2121 100644 --- a/rplugin/node/magenta/src/chat/message.ts +++ b/rplugin/node/magenta/src/chat/message.ts @@ -1,6 +1,6 @@ import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; import { Model as Part, view as partView } from "./part.ts"; -import { ToolModel } from "../tools/toolManager.ts"; +import * as ToolManager from "../tools/toolManager.ts"; import { Role } from "./chat.ts"; import { Dispatch, Update } from "../tea/tea.ts"; import { assertUnreachable } from "../utils/assertUnreachable.ts"; @@ -18,16 +18,16 @@ export type Msg = } | { type: "add-tool-use"; - toolModel: ToolModel; + requestId: ToolManager.ToolRequestId; } | { type: "add-tool-response"; - toolModel: ToolModel; + requestId: ToolManager.ToolRequestId; response: ToolResultBlockParam; } | { - type: "tool-model-update"; - toolModel: ToolModel; + type: "tool-manager-msg"; + msg: ToolManager.Msg; }; export const update: Update = (msg, model) => { @@ -48,27 +48,20 @@ export const update: Update = (msg, model) => { case "add-tool-use": model.parts.push({ type: "tool-request", - toolModel: msg.toolModel, + requestId: msg.requestId, }); break; case "add-tool-response": model.parts.push({ type: "tool-response", - toolModel: msg.toolModel, + requestId: msg.requestId, response: msg.response, }); break; - case "tool-model-update": { - for (const part of model.parts) { - if ( - (part.type == "tool-request" || part.type == "tool-response") && - part.toolModel.request.id == msg.toolModel.request.id - ) { - part.toolModel = msg.toolModel; - } - } + case "tool-manager-msg": { + // do nothing. This will be handled by the tool manager return [model]; } @@ -78,9 +71,12 @@ export const update: Update = (msg, model) => { return [model]; }; -export const view: View<{ model: Model; dispatch: Dispatch }> = ({ - model, -}) => +export const view: View<{ + model: Model; + toolManager: ToolManager.Model; + dispatch: Dispatch; +}> = ({ model, toolManager, dispatch }) => d`### ${model.role}:\n${model.parts.map( - (part) => d`${partView({ model: part })}\n`, + (part) => + d`${partView({ model: part, toolManager, dispatch: (msg) => dispatch({ type: "tool-manager-msg", msg }) })}\n`, )}`; diff --git a/rplugin/node/magenta/src/chat/part.ts b/rplugin/node/magenta/src/chat/part.ts index 342dbe1..90f559e 100644 --- a/rplugin/node/magenta/src/chat/part.ts +++ b/rplugin/node/magenta/src/chat/part.ts @@ -1,7 +1,8 @@ import Anthropic from "@anthropic-ai/sdk"; -import { renderTool, ToolModel } from "../tools/toolManager.ts"; +import * as ToolManager from "../tools/toolManager.ts"; import { assertUnreachable } from "../utils/assertUnreachable.ts"; import { d, View } from "../tea/view.ts"; +import { Dispatch } from "../tea/tea.ts"; /** A line that's meant to be sent to neovim. Should not contain newlines */ @@ -14,21 +15,27 @@ export type Model = } | { type: "tool-request"; - toolModel: ToolModel; + requestId: ToolManager.ToolRequestId; } | { type: "tool-response"; - toolModel: ToolModel; + requestId: ToolManager.ToolRequestId; response: Anthropic.ToolResultBlockParam; }; -export const view: View<{ model: Model }> = ({ model }) => { +export const view: View<{ + model: Model; + toolManager: ToolManager.Model; + dispatch: Dispatch; +}> = ({ model, dispatch, toolManager }) => { switch (model.type) { case "text": return d`${model.text}`; case "tool-request": - case "tool-response": - return renderTool(model.toolModel); + case "tool-response": { + const toolModel = toolManager.toolModels[model.requestId]; + return ToolManager.renderTool(toolModel, dispatch); + } default: assertUnreachable(model); } @@ -36,6 +43,7 @@ export const view: View<{ model: Model }> = ({ model }) => { export function toMessageParam( part: Model, + toolManager: ToolManager.Model, ): | Anthropic.TextBlockParam | Anthropic.ToolUseBlockParam @@ -43,8 +51,10 @@ export function toMessageParam( switch (part.type) { case "text": return part; - case "tool-request": - return part.toolModel.request; + case "tool-request": { + const toolModel = toolManager.toolModels[part.requestId]; + return toolModel.request; + } case "tool-response": return part.response; default: diff --git a/rplugin/node/magenta/src/debug/call-anthropic.ts b/rplugin/node/magenta/src/debug/call-anthropic.ts index 6109d15..d3afc9a 100644 --- a/rplugin/node/magenta/src/debug/call-anthropic.ts +++ b/rplugin/node/magenta/src/debug/call-anthropic.ts @@ -1,7 +1,7 @@ -import { AnthropicClient } from "../anthropic.ts"; +import { getClient } from "../anthropic.ts"; import { Logger } from "../logger.ts"; import { setContext } from "../context.ts"; -import { Neovim } from "neovim"; +import { Neovim, NvimPlugin } from "neovim"; const logger = new Logger( { @@ -15,12 +15,13 @@ const logger = new Logger( ); setContext({ + plugin: undefined as unknown as NvimPlugin, nvim: undefined as unknown as Neovim, logger, }); async function run() { - const client = new AnthropicClient(); + const client = getClient(); await client.sendMessage( [ diff --git a/rplugin/node/magenta/src/magenta.ts b/rplugin/node/magenta/src/magenta.ts index 9f0ced7..7780a7a 100644 --- a/rplugin/node/magenta/src/magenta.ts +++ b/rplugin/node/magenta/src/magenta.ts @@ -1,74 +1,25 @@ -import { AnthropicClient } from "./anthropic.ts"; import { NvimPlugin } from "neovim"; import { Sidebar } from "./sidebar.ts"; import * as Chat from "./chat/chat.ts"; import { Logger } from "./logger.ts"; import { App, createApp } from "./tea/tea.ts"; -import * as ToolManager from "./tools/toolManager.ts"; -import { d } from "./tea/view.ts"; import { setContext, context } from "./context.ts"; import { BindingKey } from "./tea/mappings.ts"; class Magenta { - private anthropicClient: AnthropicClient; private sidebar: Sidebar; private chat: App; private chatRoot: { onKey(key: BindingKey): void } | undefined; - private toolManager: App; constructor() { context.logger.debug(`Initializing plugin`); - this.anthropicClient = new AnthropicClient(); this.sidebar = new Sidebar(); this.chat = createApp({ - initialModel: { messages: [] }, + initialModel: Chat.initModel(), update: Chat.update, View: Chat.view, }); - - this.toolManager = createApp({ - initialModel: ToolManager.initModel(), - update: ToolManager.update, - View: () => d``, - onUpdate: (msg, model) => { - if (msg.type == "tool-msg") { - const toolModel = model.toolModels[msg.id]; - - // sync toolModel state w/ all the messages where it appears - this.chat.dispatch({ - type: "tool-model-update", - toolModel, - }); - - if (msg.msg.msg.type == "finish") { - const toolModel = model.toolModels[msg.id]; - const response = msg.msg.msg.result; - this.chat.dispatch({ - type: "add-tool-response", - toolModel, - response, - }); - - if (toolModel.autoRespond) { - let shouldRespond = true; - for (const tool of Object.values(model.toolModels)) { - if (tool.state.state != "done") { - shouldRespond = false; - break; - } - } - - if (shouldRespond) { - this.sendMessage().catch((err) => - context.logger.error(err as Error), - ); - } - } - } - } - }, - }); } async command(args: string[]): Promise { @@ -101,7 +52,9 @@ class Magenta { content: message, }); - await this.sendMessage(); + this.chat.dispatch({ + type: "send-message", + }); break; } @@ -114,43 +67,6 @@ class Magenta { } } - private async sendMessage() { - const state = this.chat.getState(); - if (state.status != "running") { - context.logger.error(`chat is not running.`); - return; - } - - const messages = Chat.getMessages(state.model); - - const toolRequests = await this.anthropicClient.sendMessage( - messages, - (text) => { - context.logger.trace(`stream received text ${text}`); - this.chat.dispatch({ - type: "stream-response", - text, - }); - }, - ); - - if (toolRequests.length) { - for (const request of toolRequests) { - this.toolManager.dispatch({ - type: "init-tool-use", - request, - }); - const toolManagerModel = this.toolManager.getState(); - if (toolManagerModel.status == "running") { - this.chat.dispatch({ - type: "add-tool-use", - toolModel: toolManagerModel.model.toolModels[request.id], - }); - } - } - } - } - onKey(key: BindingKey) { if (this.chatRoot) { this.chatRoot.onKey(key); diff --git a/rplugin/node/magenta/src/tea/tea.ts b/rplugin/node/magenta/src/tea/tea.ts index 7cf1f8b..d4f7170 100644 --- a/rplugin/node/magenta/src/tea/tea.ts +++ b/rplugin/node/magenta/src/tea/tea.ts @@ -250,3 +250,11 @@ export function chainThunks( } }; } + +export function parallelThunks( + ...thunks: (Thunk | undefined)[] +): Thunk { + return async (dispatch) => { + await Promise.all(thunks.map((t) => t && t(dispatch))); + }; +} diff --git a/rplugin/node/magenta/src/tools/getFile.ts b/rplugin/node/magenta/src/tools/getFile.ts index f4b7481..9037c3d 100644 --- a/rplugin/node/magenta/src/tools/getFile.ts +++ b/rplugin/node/magenta/src/tools/getFile.ts @@ -7,6 +7,7 @@ import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; import { Thunk, Update } from "../tea/tea.ts"; import { d, VDOMNode } from "../tea/view.ts"; import { context } from "../context.ts"; +import { ToolRequestId } from "./toolManager.ts"; export type Model = { type: "get-file"; @@ -165,7 +166,7 @@ export const spec: Anthropic.Anthropic.Tool = { export type GetFileToolUseRequest = { type: "tool_use"; - id: string; //"toolu_01UJtsBsBED9bwkonjqdxji4" + id: ToolRequestId; //"toolu_01UJtsBsBED9bwkonjqdxji4" name: "get_file"; input: { filePath: string; //"./src/index.ts" diff --git a/rplugin/node/magenta/src/tools/insert.ts b/rplugin/node/magenta/src/tools/insert.ts index 3a5a1e7..d6e2778 100644 --- a/rplugin/node/magenta/src/tools/insert.ts +++ b/rplugin/node/magenta/src/tools/insert.ts @@ -3,8 +3,9 @@ import { Buffer } from "neovim"; import { assertUnreachable } from "../utils/assertUnreachable.ts"; import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; import { Dispatch, Update } from "../tea/tea.ts"; -import { d, VDOMNode } from "../tea/view.ts"; +import { d, VDOMNode, withBindings } from "../tea/view.ts"; import { context } from "../context.ts"; +import { ToolRequestId } from "./toolManager.ts"; export type Model = { type: "insert"; @@ -14,6 +15,9 @@ export type Model = { | { state: "pending-user-action"; } + | { + state: "editing-diff"; + } | { state: "done"; result: ToolResultBlockParam; @@ -56,7 +60,7 @@ export const update: Update = (msg, model) => { } }; -export function initModel(request: InsertToolUseRequest) { +export function initModel(request: InsertToolUseRequest): [Model] { const model: Model = { type: "insert", autoRespond: false, @@ -133,10 +137,39 @@ export function insertThunk(model: Model) { }; } -export function view({ model }: { model: Model }): VDOMNode { - return d`Insert operation ${ - model.state.state === "done" ? "completed" : "in progress" - }`; +export function view({ + model, + dispatch, +}: { + model: Model; + dispatch: Dispatch; +}): VDOMNode { + return d`Insert ${( + model.request.input.content.match(/\n/g) || [] + ).length.toString()} into file ${model.request.input.filePath} +${toolStatusView({ model, dispatch })}`; +} + +function toolStatusView({ + model, + dispatch, +}: { + model: Model; + dispatch: Dispatch; +}): VDOMNode { + switch (model.state.state) { + case "pending-user-action": + return withBindings(d`[review diff]`, { + Enter: () => + dispatch({ + type: "display-diff", + }), + }); + case "editing-diff": + return d`Editing diff`; + case "done": + return d`Done`; + } } export const spec: Anthropic.Anthropic.Tool = { @@ -164,7 +197,7 @@ export const spec: Anthropic.Anthropic.Tool = { export type InsertToolUseRequest = { type: "tool_use"; - id: string; + id: ToolRequestId; name: "insert"; input: { filePath: string; diff --git a/rplugin/node/magenta/src/tools/toolManager.ts b/rplugin/node/magenta/src/tools/toolManager.ts index 4c74c72..c23f145 100644 --- a/rplugin/node/magenta/src/tools/toolManager.ts +++ b/rplugin/node/magenta/src/tools/toolManager.ts @@ -1,6 +1,6 @@ import * as GetFile from "./getFile.ts"; import * as Insert from "./insert.ts"; -import { Update } from "../tea/tea.ts"; +import { Dispatch, Update } from "../tea/tea.ts"; import { assertUnreachable } from "../utils/assertUnreachable.ts"; export type ToolRequest = @@ -9,20 +9,30 @@ export type ToolRequest = export type ToolModel = GetFile.Model | Insert.Model; +export type ToolRequestId = string & { __toolRequestId: true }; + export const TOOL_SPECS = [GetFile.spec, Insert.spec]; export type Model = { toolModels: { - [id: string]: GetFile.Model | Insert.Model; + [id: ToolRequestId]: GetFile.Model | Insert.Model; }; }; -export function renderTool(model: ToolModel) { +export function renderTool(model: ToolModel, dispatch: Dispatch) { switch (model.type) { case "get-file": return GetFile.view({ model }); case "insert": - return Insert.view({ model }); + return Insert.view({ + model, + dispatch: (msg) => + dispatch({ + type: "tool-msg", + id: model.request.id, + msg: { type: "insert", msg }, + }), + }); default: assertUnreachable(model); } @@ -35,7 +45,7 @@ export type Msg = } | { type: "tool-msg"; - id: string; + id: ToolRequestId; msg: | { type: "get-file"; @@ -83,7 +93,7 @@ export const update: Update = (msg, model) => { } case "insert": { - const [insertModel, thunk] = Insert.initModel(request); + const [insertModel] = Insert.initModel(request); return [ { ...model, @@ -92,17 +102,6 @@ export const update: Update = (msg, model) => { [request.id]: insertModel, }, }, - (dispatch) => - thunk((msg) => - dispatch({ - type: "tool-msg", - id: request.id, - msg: { - type: "insert", - msg, - }, - }), - ), ]; } default: diff --git a/rplugin/node/magenta/test/preamble.ts b/rplugin/node/magenta/test/preamble.ts index 42602a2..f0998c5 100644 --- a/rplugin/node/magenta/test/preamble.ts +++ b/rplugin/node/magenta/test/preamble.ts @@ -1,4 +1,4 @@ -import { attach, NeovimClient } from "neovim"; +import { attach, NeovimClient, NvimPlugin } from "neovim"; import { spawn } from "child_process"; import { MountedVDOM } from "../src/tea/view.ts"; import { assertUnreachable } from "../src/utils/assertUnreachable.ts"; @@ -33,6 +33,7 @@ export class NeovimTestHelper { try { this.nvimClient = attach({ proc: this.nvimProcess }); setContext({ + plugin: undefined as unknown as NvimPlugin, nvim: this.nvimClient, logger: { log: console.log,