From 4226b30ab72210fab85d4959462c99312803da6c Mon Sep 17 00:00:00 2001 From: Denis Lantsman Date: Sat, 14 Dec 2024 09:08:59 -0800 Subject: [PATCH] further conversion to tea architecture --- rplugin/node/magenta/src/anthropic.ts | 16 ++--- rplugin/node/magenta/src/chat/chat.ts | 56 ++++++++++++++- rplugin/node/magenta/src/magenta.ts | 81 +++++++++++++++------- rplugin/node/magenta/src/sidebar.ts | 47 +++++++------ rplugin/node/magenta/src/tea/tea.ts | 20 +++++- rplugin/node/magenta/src/tea/util.ts | 2 +- rplugin/node/magenta/src/tea/view.spec.ts | 21 +----- rplugin/node/magenta/src/tea/view.ts | 7 +- rplugin/node/magenta/src/tools/getFile.ts | 2 +- rplugin/node/magenta/src/tools/insert.ts | 2 +- rplugin/node/magenta/src/types.ts | 10 +-- rplugin/node/magenta/src/utils/extmarks.ts | 2 +- rplugin/node/magenta/test/preamble.ts | 4 ++ 13 files changed, 181 insertions(+), 89 deletions(-) diff --git a/rplugin/node/magenta/src/anthropic.ts b/rplugin/node/magenta/src/anthropic.ts index 894c018..99da9f6 100644 --- a/rplugin/node/magenta/src/anthropic.ts +++ b/rplugin/node/magenta/src/anthropic.ts @@ -19,7 +19,7 @@ export class AnthropicClient { async sendMessage( messages: Array, - onText: (text: string) => Promise, + onText: (text: string) => void, ): Promise { this.logger.trace( `initializing stream with messages: ${JSON.stringify(messages, null, 2)}`, @@ -34,14 +34,12 @@ export class AnthropicClient { flushInProgress = true; - onText(text) - .catch((e: Error) => { - this.logger.error(e); - }) - .finally(() => { - flushInProgress = false; - setInterval(flushBuffer, 1); - }); + try { + onText(text); + } finally { + flushInProgress = false; + setInterval(flushBuffer, 1); + } } }; diff --git a/rplugin/node/magenta/src/chat/chat.ts b/rplugin/node/magenta/src/chat/chat.ts index b6c0988..97451ad 100644 --- a/rplugin/node/magenta/src/chat/chat.ts +++ b/rplugin/node/magenta/src/chat/chat.ts @@ -1,9 +1,16 @@ import Anthropic from "@anthropic-ai/sdk"; import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; import { toMessageParam } from "./part.js"; -import { Model as Message, update as updateMessage } from "./message.js"; +import { + Model as Message, + Msg as MessageMsg, + update as updateMessage, + view as messageView, +} from "./message.js"; import { ToolRequest } from "../tools/index.js"; -import { Update } from "../tea/tea.js"; +import { Dispatch, Update } from "../tea/tea.js"; +import { ToolProcess } from "../tools/types.js"; +import { d, View } from "../tea/view.js"; export type Role = "user" | "assistant"; @@ -23,11 +30,25 @@ export type Model = { }; export type Msg = + | { + type: "message-msg"; + msg: MessageMsg; + idx: number; + } | { type: "add-message"; role: Role; content?: string; } + | { + type: "stream-response"; + text: string; + } + | { + type: "add-tool-use"; + request: ToolRequest; + process: ToolProcess; + } | { type: "add-tool-response"; request: ToolRequest; @@ -55,6 +76,22 @@ export const update: Update = (msg, model) => { model.messages.push(message); return [model]; } + + case "message-msg": { + // TODO + return [model]; + } + + case "stream-response": { + // TODO + return [model]; + } + + case "add-tool-use": { + // TODO + return [model]; + } + case "add-tool-response": { let lastMessage = model.messages[model.messages.length - 1]; if (lastMessage.role != "user") { @@ -82,6 +119,21 @@ export const update: Update = (msg, model) => { } }; +export const view: View<{ model: Model; dispatch: Dispatch }> = ({ + model, + dispatch, +}) => { + return d`# Chat +${model.messages.map((m, idx) => + messageView({ + model: m, + dispatch: (msg) => { + dispatch({ type: "message-msg", msg, idx }); + }, + }), +)}`; +}; + export function getMessages(model: Model): Anthropic.MessageParam[] { return model.messages.map((msg) => ({ role: msg.role, diff --git a/rplugin/node/magenta/src/magenta.ts b/rplugin/node/magenta/src/magenta.ts index 944aa05..c8b2cdf 100644 --- a/rplugin/node/magenta/src/magenta.ts +++ b/rplugin/node/magenta/src/magenta.ts @@ -1,33 +1,45 @@ import { AnthropicClient } from "./anthropic.js"; import { NvimPlugin } from "neovim"; import { Sidebar } from "./sidebar.js"; -import { view as chatView } from "./chat/chat.js"; +import { + Model as ChatModel, + Msg as ChatMsg, + getMessages, + view as chatView, + update as chatUpdate, +} from "./chat/chat.js"; import { Logger } from "./logger.js"; import { Context } from "./types.js"; import { TOOLS } from "./tools/index.js"; import { assertUnreachable } from "./utils/assertUnreachable.js"; import { ToolProcess } from "./tools/types.js"; import { Moderator } from "./moderator.js"; +import { App, createApp } from "./tea/tea.js"; class Magenta { private anthropicClient: AnthropicClient; private sidebar: Sidebar; private moderator: Moderator; + private chat: App; - constructor( - private context: Context, - private chat: Chat, - ) { + constructor(private context: Context) { this.context.logger.debug(`Initializing plugin`); this.anthropicClient = new AnthropicClient(this.context.logger); this.sidebar = new Sidebar(this.context.nvim, this.context.logger); + this.chat = createApp({ + initialModel: { messages: [] }, + update: chatUpdate, + View: chatView, + }); this.moderator = new Moderator( this.context, // on tool result - (req, res) => { - this.chat - .addToolResponse(req, res) - .catch((err) => this.context.logger.error(err as Error)); + (request, response) => { + this.chat.dispatch({ + type: "add-tool-response", + request, + response, + }); }, // autorespond () => { @@ -42,7 +54,15 @@ class Magenta { this.context.logger.debug(`Received command ${args[0]}`); switch (args[0]) { case "toggle": { - await this.sidebar.toggle(this.chat.displayBuffer); + const buffers = await this.sidebar.toggle(); + if (buffers) { + await this.chat.mount({ + nvim: this.context.nvim, + buffer: buffers.displayBuffer, + startPos: { row: 0, col: 0 }, + endPos: { row: 0, col: 0 }, + }); + } break; } @@ -51,14 +71,18 @@ class Magenta { this.context.logger.trace(`current message: ${message}`); if (!message) return; - await this.chat.addMessage("user", message); + this.chat.dispatch({ + type: "add-message", + role: "user", + content: message, + }); await this.sendMessage(); break; } case "clear": - this.chat.clear(); + this.chat.dispatch({ type: "clear" }); break; default: @@ -67,14 +91,22 @@ class Magenta { } private async sendMessage() { - const messages = this.chat.getMessages(); + const state = this.chat.getState(); + if (state.status != "running") { + this.context.logger.error(`chat is not running.`); + return; + } + + const messages = getMessages(state.model); - const currentMessage = await this.chat.addMessage("assistant", ""); const toolRequests = await this.anthropicClient.sendMessage( messages, - async (text) => { + (text) => { this.context.logger.trace(`stream received text ${text}`); - await currentMessage.appendText(text); + this.chat.dispatch({ + type: "stream-response", + text, + }); }, ); @@ -95,18 +127,17 @@ class Magenta { } this.moderator.registerProcess(process); - await currentMessage.addToolUse(request, process); + this.chat.dispatch({ + type: "add-tool-use", + request, + process, + }); } } } - - static async init(plugin: NvimPlugin, logger: Logger) { - const chat = await Chat.init({ nvim: plugin.nvim, logger }); - return new Magenta({ nvim: plugin.nvim, logger }, chat); - } } -let init: { magenta: Promise; logger: Logger } | undefined = undefined; +let init: { magenta: Magenta; logger: Logger } | undefined = undefined; module.exports = (plugin: NvimPlugin) => { plugin.setOptions({}); @@ -119,7 +150,7 @@ module.exports = (plugin: NvimPlugin) => { }); init = { - magenta: Magenta.init(plugin, logger), + magenta: new Magenta({ nvim: plugin.nvim, logger }), logger, }; } @@ -128,7 +159,7 @@ module.exports = (plugin: NvimPlugin) => { "Magenta", async (args: string[]) => { try { - const magenta = await init!.magenta; + const magenta = init!.magenta; await magenta.command(args); } catch (err) { init!.logger.error(err as Error); diff --git a/rplugin/node/magenta/src/sidebar.ts b/rplugin/node/magenta/src/sidebar.ts index e4bb5a3..209d8d7 100644 --- a/rplugin/node/magenta/src/sidebar.ts +++ b/rplugin/node/magenta/src/sidebar.ts @@ -6,11 +6,10 @@ import { Logger } from "./logger.js"; export class Sidebar { private state: | { - state: "not-loaded"; + state: "hidden"; } | { - state: "loaded"; - visible: boolean; + state: "visible"; displayBuffer: Buffer; inputBuffer: Buffer; displayWindow: Window; @@ -21,24 +20,26 @@ export class Sidebar { private nvim: Neovim, private logger: Logger, ) { - this.state = { state: "not-loaded" }; + this.state = { state: "hidden" }; + // TODO: also probably need to set up some autocommands to detect if the user closes the scratch buffers } - /** returns the input buffer when it was created + /** returns buffers when they are visible */ - async toggle(displayBuffer: Buffer): Promise { - if (this.state.state == "not-loaded") { - await this.create(displayBuffer); + async toggle(): Promise< + { displayBuffer: Buffer; inputBuffer: Buffer } | undefined + > { + if (this.state.state == "hidden") { + return await this.create(); } else { - if (this.state.visible) { - await this.hide(); - } else { - await this.show(); - } + await this.destroy(); } } - private async create(displayBuffer: Buffer): Promise { + private async create(): Promise<{ + displayBuffer: Buffer; + inputBuffer: Buffer; + }> { const { nvim, logger } = this; logger.trace(`sidebar.create`); const totalHeight = (await nvim.getOption("lines")) as number; @@ -49,6 +50,7 @@ export class Sidebar { await nvim.command("leftabove vsplit"); const displayWindow = await nvim.window; + const displayBuffer = (await this.nvim.createBuffer(false, true)) as Buffer; displayWindow.width = width; await nvim.lua( `vim.api.nvim_win_set_buf(${displayWindow.id}, ${displayBuffer.id})`, @@ -93,20 +95,23 @@ export class Sidebar { logger.trace(`sidebar.create setting state`); this.state = { - state: "loaded", - visible: true, + state: "visible", displayBuffer, inputBuffer, displayWindow, inputWindow, }; - return inputBuffer; + return { displayBuffer, inputBuffer }; } - async hide() {} + async destroy() { + this.state = { + state: "hidden", + }; - async show() {} + // TODO: clean up buffers + } async scrollTop() { // const { displayWindow } = await this.getWindowIfVisible(); @@ -125,7 +130,7 @@ export class Sidebar { displayWindow?: Window; inputWindow?: Window; }> { - if (this.state.state != "loaded") { + if (this.state.state != "visible") { return {}; } @@ -140,7 +145,7 @@ export class Sidebar { } async getMessage(): Promise { - if (this.state.state != "loaded") { + if (this.state.state != "visible") { this.logger.trace(`sidebar state is ${this.state.state} in getMessage`); return ""; } diff --git a/rplugin/node/magenta/src/tea/tea.ts b/rplugin/node/magenta/src/tea/tea.ts index b6d7d2d..6886aa4 100644 --- a/rplugin/node/magenta/src/tea/tea.ts +++ b/rplugin/node/magenta/src/tea/tea.ts @@ -38,6 +38,13 @@ type AppState = error: string; }; +export type App = { + mount(mount: MountPoint): Promise; + unmount(): void; + dispatch: Dispatch; + getState(): AppState; +}; + export function createApp({ initialModel, update, @@ -51,7 +58,7 @@ export function createApp({ subscriptions: (model: Model) => Subscription[]; subscriptionManager: SubscriptionManager; }; -}) { +}): App { let currentState: AppState = { status: "running", model: initialModel, @@ -146,7 +153,16 @@ export function createApp({ mount, props: { currentState, dispatch }, }); - return { root, dispatch }; + }, + unmount() { + if (root) { + root.unmount(); + root = undefined; + } + }, + dispatch, + getState() { + return currentState; }, }; } diff --git a/rplugin/node/magenta/src/tea/util.ts b/rplugin/node/magenta/src/tea/util.ts index c3e6289..fcbd36e 100644 --- a/rplugin/node/magenta/src/tea/util.ts +++ b/rplugin/node/magenta/src/tea/util.ts @@ -1,6 +1,6 @@ import { Neovim, Buffer } from "neovim"; import { Position } from "./view.js"; -import { Line } from "../part.js"; +import { Line } from "../chat/part.js"; export async function replaceBetweenPositions({ nvim, diff --git a/rplugin/node/magenta/src/tea/view.spec.ts b/rplugin/node/magenta/src/tea/view.spec.ts index c07590e..2d3b039 100644 --- a/rplugin/node/magenta/src/tea/view.spec.ts +++ b/rplugin/node/magenta/src/tea/view.spec.ts @@ -1,9 +1,8 @@ import type { NeovimClient, Buffer } from "neovim"; -import { NeovimTestHelper } from "../../test/preamble.js"; -import { d, MountedVDOM, mountView } from "./view.js"; +import { extractMountTree, NeovimTestHelper } from "../../test/preamble.js"; +import { d, mountView } from "./view.js"; import * as assert from "assert"; import { test } from "node:test"; -import { assertUnreachable } from "../utils/assertUnreachable.js"; await test.describe("Neovim Plugin Tests", async () => { let helper: NeovimTestHelper; @@ -165,19 +164,3 @@ await test.describe("Neovim Plugin Tests", async () => { ); }); }); - -function extractMountTree(mounted: MountedVDOM): unknown { - switch (mounted.type) { - case "string": - return mounted; - case "node": - return { - type: "node", - children: mounted.children.map(extractMountTree), - startPos: mounted.startPos, - endPos: mounted.endPos, - }; - default: - assertUnreachable(mounted); - } -} diff --git a/rplugin/node/magenta/src/tea/view.ts b/rplugin/node/magenta/src/tea/view.ts index 9d942a7..9a0ac7f 100644 --- a/rplugin/node/magenta/src/tea/view.ts +++ b/rplugin/node/magenta/src/tea/view.ts @@ -57,8 +57,8 @@ export type MountedVDOM = export type MountedView

= { render(props: P): Promise; - /** for testing - */ + unmount(): void; + /** for testing */ _getMountedNode(): MountedVDOM; }; @@ -81,6 +81,9 @@ export async function mountView

({ mount, }); }, + unmount() { + // TODO + }, _getMountedNode: () => mountedNode, }; } diff --git a/rplugin/node/magenta/src/tools/getFile.ts b/rplugin/node/magenta/src/tools/getFile.ts index e04b1a2..73a81c5 100644 --- a/rplugin/node/magenta/src/tools/getFile.ts +++ b/rplugin/node/magenta/src/tools/getFile.ts @@ -3,7 +3,7 @@ import { Context } from "../types.js"; import { getBufferIfOpen } from "../utils/buffers.js"; import fs from "fs"; import path from "path"; -import { Line } from "../part.js"; +import { Line } from "../chat/part.js"; import { assertUnreachable } from "../utils/assertUnreachable.js"; import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; diff --git a/rplugin/node/magenta/src/tools/insert.ts b/rplugin/node/magenta/src/tools/insert.ts index 654a093..c789fe5 100644 --- a/rplugin/node/magenta/src/tools/insert.ts +++ b/rplugin/node/magenta/src/tools/insert.ts @@ -2,7 +2,7 @@ import * as Anthropic from "@anthropic-ai/sdk"; import { Context } from "../types.js"; import {} from "@anthropic-ai/sdk"; import { Buffer } from "neovim"; -import { Line } from "../part.js"; +import { Line } from "../chat/part.js"; import { assertUnreachable } from "../utils/assertUnreachable.js"; import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; diff --git a/rplugin/node/magenta/src/types.ts b/rplugin/node/magenta/src/types.ts index c727746..b937a52 100644 --- a/rplugin/node/magenta/src/types.ts +++ b/rplugin/node/magenta/src/types.ts @@ -1,7 +1,7 @@ -import { Neovim } from "neovim" -import { Logger } from "./logger.js" +import { Neovim } from "neovim"; +import { Logger } from "./logger.js"; export type Context = { - nvim: Neovim, - logger: Logger -} + nvim: Neovim; + logger: Logger; +}; diff --git a/rplugin/node/magenta/src/utils/extmarks.ts b/rplugin/node/magenta/src/utils/extmarks.ts index fd793ed..710f2fb 100644 --- a/rplugin/node/magenta/src/utils/extmarks.ts +++ b/rplugin/node/magenta/src/utils/extmarks.ts @@ -1,5 +1,5 @@ import { Buffer, Neovim } from "neovim"; -import { Line } from "../part.js"; +import { Line } from "../chat/part.js"; export type Mark = number & { __mark: true }; export type MarkOpts = { details: { is_start: boolean } }; diff --git a/rplugin/node/magenta/test/preamble.ts b/rplugin/node/magenta/test/preamble.ts index 7dd9015..53ab8ce 100644 --- a/rplugin/node/magenta/test/preamble.ts +++ b/rplugin/node/magenta/test/preamble.ts @@ -62,6 +62,10 @@ export function extractMountTree(mounted: MountedVDOM): unknown { startPos: mounted.startPos, endPos: mounted.endPos, }; + + case "array": + return mounted; + default: assertUnreachable(mounted); }