Skip to content

Commit

Permalink
further conversion to tea architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
dlants committed Dec 14, 2024
1 parent 5a63e1f commit 4226b30
Show file tree
Hide file tree
Showing 13 changed files with 181 additions and 89 deletions.
16 changes: 7 additions & 9 deletions rplugin/node/magenta/src/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export class AnthropicClient {

async sendMessage(
messages: Array<Anthropic.MessageParam>,
onText: (text: string) => Promise<void>,
onText: (text: string) => void,
): Promise<ToolRequest[]> {
this.logger.trace(
`initializing stream with messages: ${JSON.stringify(messages, null, 2)}`,
Expand All @@ -34,14 +34,12 @@ export class AnthropicClient {

flushInProgress = true;

onText(text)
.catch((e: Error) => {
this.logger.error(e);
})
.finally(() => {
flushInProgress = false;
setInterval(flushBuffer, 1);
});
try {
onText(text);
} finally {
flushInProgress = false;
setInterval(flushBuffer, 1);
}
}
};

Expand Down
56 changes: 54 additions & 2 deletions rplugin/node/magenta/src/chat/chat.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import Anthropic from "@anthropic-ai/sdk";
import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs";
import { toMessageParam } from "./part.js";
import { Model as Message, update as updateMessage } from "./message.js";
import {
Model as Message,
Msg as MessageMsg,
update as updateMessage,
view as messageView,
} from "./message.js";
import { ToolRequest } from "../tools/index.js";
import { Update } from "../tea/tea.js";
import { Dispatch, Update } from "../tea/tea.js";
import { ToolProcess } from "../tools/types.js";
import { d, View } from "../tea/view.js";

export type Role = "user" | "assistant";

Expand All @@ -23,11 +30,25 @@ export type Model = {
};

export type Msg =
| {
type: "message-msg";
msg: MessageMsg;
idx: number;
}
| {
type: "add-message";
role: Role;
content?: string;
}
| {
type: "stream-response";
text: string;
}
| {
type: "add-tool-use";
request: ToolRequest;
process: ToolProcess;
}
| {
type: "add-tool-response";
request: ToolRequest;
Expand Down Expand Up @@ -55,6 +76,22 @@ export const update: Update<Msg, Model> = (msg, model) => {
model.messages.push(message);
return [model];
}

case "message-msg": {
// TODO
return [model];
}

case "stream-response": {
// TODO
return [model];
}

case "add-tool-use": {
// TODO
return [model];
}

case "add-tool-response": {
let lastMessage = model.messages[model.messages.length - 1];
if (lastMessage.role != "user") {
Expand Down Expand Up @@ -82,6 +119,21 @@ export const update: Update<Msg, Model> = (msg, model) => {
}
};

export const view: View<{ model: Model; dispatch: Dispatch<Msg> }> = ({
model,
dispatch,
}) => {
return d`# Chat
${model.messages.map((m, idx) =>
messageView({
model: m,
dispatch: (msg) => {
dispatch({ type: "message-msg", msg, idx });
},
}),
)}`;
};

export function getMessages(model: Model): Anthropic.MessageParam[] {
return model.messages.map((msg) => ({
role: msg.role,
Expand Down
81 changes: 56 additions & 25 deletions rplugin/node/magenta/src/magenta.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,45 @@
import { AnthropicClient } from "./anthropic.js";
import { NvimPlugin } from "neovim";
import { Sidebar } from "./sidebar.js";
import { view as chatView } from "./chat/chat.js";
import {
Model as ChatModel,
Msg as ChatMsg,
getMessages,
view as chatView,
update as chatUpdate,
} from "./chat/chat.js";
import { Logger } from "./logger.js";
import { Context } from "./types.js";
import { TOOLS } from "./tools/index.js";
import { assertUnreachable } from "./utils/assertUnreachable.js";
import { ToolProcess } from "./tools/types.js";
import { Moderator } from "./moderator.js";
import { App, createApp } from "./tea/tea.js";

class Magenta {
private anthropicClient: AnthropicClient;
private sidebar: Sidebar;
private moderator: Moderator;
private chat: App<ChatMsg, ChatModel>;

constructor(
private context: Context,
private chat: Chat,
) {
constructor(private context: Context) {
this.context.logger.debug(`Initializing plugin`);
this.anthropicClient = new AnthropicClient(this.context.logger);
this.sidebar = new Sidebar(this.context.nvim, this.context.logger);
this.chat = createApp({
initialModel: { messages: [] },
update: chatUpdate,
View: chatView,
});
this.moderator = new Moderator(
this.context,
// on tool result
(req, res) => {
this.chat
.addToolResponse(req, res)
.catch((err) => this.context.logger.error(err as Error));
(request, response) => {
this.chat.dispatch({
type: "add-tool-response",
request,
response,
});
},
// autorespond
() => {
Expand All @@ -42,7 +54,15 @@ class Magenta {
this.context.logger.debug(`Received command ${args[0]}`);
switch (args[0]) {
case "toggle": {
await this.sidebar.toggle(this.chat.displayBuffer);
const buffers = await this.sidebar.toggle();
if (buffers) {
await this.chat.mount({
nvim: this.context.nvim,
buffer: buffers.displayBuffer,
startPos: { row: 0, col: 0 },
endPos: { row: 0, col: 0 },
});
}
break;
}

Expand All @@ -51,14 +71,18 @@ class Magenta {
this.context.logger.trace(`current message: ${message}`);
if (!message) return;

await this.chat.addMessage("user", message);
this.chat.dispatch({
type: "add-message",
role: "user",
content: message,
});

await this.sendMessage();
break;
}

case "clear":
this.chat.clear();
this.chat.dispatch({ type: "clear" });
break;

default:
Expand All @@ -67,14 +91,22 @@ class Magenta {
}

private async sendMessage() {
const messages = this.chat.getMessages();
const state = this.chat.getState();
if (state.status != "running") {
this.context.logger.error(`chat is not running.`);
return;
}

const messages = getMessages(state.model);

const currentMessage = await this.chat.addMessage("assistant", "");
const toolRequests = await this.anthropicClient.sendMessage(
messages,
async (text) => {
(text) => {
this.context.logger.trace(`stream received text ${text}`);
await currentMessage.appendText(text);
this.chat.dispatch({
type: "stream-response",
text,
});
},
);

Expand All @@ -95,18 +127,17 @@ class Magenta {
}

this.moderator.registerProcess(process);
await currentMessage.addToolUse(request, process);
this.chat.dispatch({
type: "add-tool-use",
request,
process,
});
}
}
}

static async init(plugin: NvimPlugin, logger: Logger) {
const chat = await Chat.init({ nvim: plugin.nvim, logger });
return new Magenta({ nvim: plugin.nvim, logger }, chat);
}
}

let init: { magenta: Promise<Magenta>; logger: Logger } | undefined = undefined;
let init: { magenta: Magenta; logger: Logger } | undefined = undefined;

module.exports = (plugin: NvimPlugin) => {
plugin.setOptions({});
Expand All @@ -119,7 +150,7 @@ module.exports = (plugin: NvimPlugin) => {
});

init = {
magenta: Magenta.init(plugin, logger),
magenta: new Magenta({ nvim: plugin.nvim, logger }),
logger,
};
}
Expand All @@ -128,7 +159,7 @@ module.exports = (plugin: NvimPlugin) => {
"Magenta",
async (args: string[]) => {
try {
const magenta = await init!.magenta;
const magenta = init!.magenta;
await magenta.command(args);
} catch (err) {
init!.logger.error(err as Error);
Expand Down
47 changes: 26 additions & 21 deletions rplugin/node/magenta/src/sidebar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ import { Logger } from "./logger.js";
export class Sidebar {
private state:
| {
state: "not-loaded";
state: "hidden";
}
| {
state: "loaded";
visible: boolean;
state: "visible";
displayBuffer: Buffer;
inputBuffer: Buffer;
displayWindow: Window;
Expand All @@ -21,24 +20,26 @@ export class Sidebar {
private nvim: Neovim,
private logger: Logger,
) {
this.state = { state: "not-loaded" };
this.state = { state: "hidden" };
// TODO: also probably need to set up some autocommands to detect if the user closes the scratch buffers
}

/** returns the input buffer when it was created
/** returns buffers when they are visible
*/
async toggle(displayBuffer: Buffer): Promise<void> {
if (this.state.state == "not-loaded") {
await this.create(displayBuffer);
async toggle(): Promise<
{ displayBuffer: Buffer; inputBuffer: Buffer } | undefined
> {
if (this.state.state == "hidden") {
return await this.create();
} else {
if (this.state.visible) {
await this.hide();
} else {
await this.show();
}
await this.destroy();
}
}

private async create(displayBuffer: Buffer): Promise<Buffer> {
private async create(): Promise<{
displayBuffer: Buffer;
inputBuffer: Buffer;
}> {
const { nvim, logger } = this;
logger.trace(`sidebar.create`);
const totalHeight = (await nvim.getOption("lines")) as number;
Expand All @@ -49,6 +50,7 @@ export class Sidebar {

await nvim.command("leftabove vsplit");
const displayWindow = await nvim.window;
const displayBuffer = (await this.nvim.createBuffer(false, true)) as Buffer;
displayWindow.width = width;
await nvim.lua(
`vim.api.nvim_win_set_buf(${displayWindow.id}, ${displayBuffer.id})`,
Expand Down Expand Up @@ -93,20 +95,23 @@ export class Sidebar {

logger.trace(`sidebar.create setting state`);
this.state = {
state: "loaded",
visible: true,
state: "visible",
displayBuffer,
inputBuffer,
displayWindow,
inputWindow,
};

return inputBuffer;
return { displayBuffer, inputBuffer };
}

async hide() {}
async destroy() {
this.state = {
state: "hidden",
};

async show() {}
// TODO: clean up buffers
}

async scrollTop() {
// const { displayWindow } = await this.getWindowIfVisible();
Expand All @@ -125,7 +130,7 @@ export class Sidebar {
displayWindow?: Window;
inputWindow?: Window;
}> {
if (this.state.state != "loaded") {
if (this.state.state != "visible") {
return {};
}

Expand All @@ -140,7 +145,7 @@ export class Sidebar {
}

async getMessage(): Promise<string> {
if (this.state.state != "loaded") {
if (this.state.state != "visible") {
this.logger.trace(`sidebar state is ${this.state.state} in getMessage`);
return "";
}
Expand Down
Loading

0 comments on commit 4226b30

Please sign in to comment.