Skip to content

Commit

Permalink
refactor everything to be in one model, use thunks for sendmessage, r…
Browse files Browse the repository at this point in the history
…eference toolManager from parts
  • Loading branch information
dlants committed Dec 16, 2024
1 parent 3dcaf39 commit 5b3a8da
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 198 deletions.
12 changes: 11 additions & 1 deletion rplugin/node/magenta/src/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Anthropic from "@anthropic-ai/sdk";
import { context } from "./context.ts";
import { TOOL_SPECS, ToolRequest } from "./tools/toolManager.ts";

export class AnthropicClient {
class AnthropicClient {
private client: Anthropic;

constructor() {
Expand Down Expand Up @@ -71,3 +71,13 @@ export class AnthropicClient {
return toolRequests;
}
}

let client: AnthropicClient | undefined;

// lazy load so we have a chance to init context before constructing the class
export function getClient() {
if (!client) {
client = new AnthropicClient();
}
return client;
}
182 changes: 130 additions & 52 deletions rplugin/node/magenta/src/chat/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,17 @@ import {
view as messageView,
} from "./message.ts";
import { ToolModel } from "../tools/toolManager.ts";
import { Dispatch, Update } from "../tea/tea.ts";
import {
Dispatch,
parallelThunks,
Thunk,
Update,
wrapThunk,
} from "../tea/tea.ts";
import { d, View, withBindings } from "../tea/view.ts";
import { context } from "../context.ts";
import * as ToolManager from "../tools/toolManager.ts";
import { getClient } from "../anthropic.ts";

export type Role = "user" | "assistant";

Expand All @@ -25,8 +33,16 @@ export type ChatState =
state: "awaiting-tool-use";
};

export function initModel(): Model {
return {
messages: [],
toolManager: ToolManager.initModel(),
};
}

export type Model = {
messages: Message[];
toolManager: ToolManager.Model;
};

export type Msg =
Expand All @@ -45,20 +61,18 @@ export type Msg =
text: string;
}
| {
type: "add-tool-use";
toolModel: ToolModel;
type: "init-tool-use";
request: ToolManager.ToolRequest;
}
| {
type: "add-tool-response";
toolModel: ToolModel;
response: ToolResultBlockParam;
type: "send-message";
}
| {
type: "tool-model-update";
toolModel: ToolModel;
type: "clear";
}
| {
type: "clear";
type: "tool-manager-msg";
msg: ToolManager.Msg;
};

export const update: Update<Msg, Model> = (msg, model) => {
Expand All @@ -80,9 +94,27 @@ export const update: Update<Msg, Model> = (msg, model) => {
return [model];
}

case "send-message": {
const lastMessage = model.messages[model.messages.length - 1];
if (lastMessage && lastMessage.role == "user") {
return [model, sendMessage(model)];
}
return [model];
}

case "message-msg": {
const [nextMessage] = updateMessage(msg.msg, model.messages[msg.idx]);
model.messages[msg.idx] = nextMessage;

if (msg.msg.type == "tool-manager-msg") {
const [nextToolManager, toolManagerThunk] = ToolManager.update(
msg.msg.msg,
model.toolManager,
);
model.toolManager = nextToolManager;
return [model, wrapThunk("tool-manager-msg", toolManagerThunk)];
}

return [model];
}

Expand All @@ -104,72 +136,117 @@ export const update: Update<Msg, Model> = (msg, model) => {
return [model];
}

case "add-tool-use": {
case "init-tool-use": {
const [nextToolManager, toolManagerThunk] = ToolManager.update(
{ type: "init-tool-use", request: msg.request },
model.toolManager,
);
model.toolManager = nextToolManager;

const lastMessage = model.messages[model.messages.length - 1];
if (lastMessage?.role !== "assistant") {
model.messages.push({
role: "assistant",
parts: [],
});
}

const [nextMessage] = updateMessage(
{ type: "add-tool-use", toolModel: msg.toolModel },
{ type: "add-tool-use", requestId: msg.request.id },
model.messages[model.messages.length - 1],
);
model.messages[model.messages.length - 1] = nextMessage;
return [model];
return [model, wrapThunk("tool-manager-msg", toolManagerThunk)];
}

case "add-tool-response": {
let lastMessage = model.messages[model.messages.length - 1];
if (lastMessage?.role !== "user") {
lastMessage = {
role: "user",
parts: [],
};
model.messages.push(lastMessage);
}

const [next] = updateMessage(
{
type: "add-tool-response",
toolModel: msg.toolModel,
response: msg.response,
},
lastMessage,
case "tool-manager-msg": {
const [nextToolManager, toolManagerThunk] = ToolManager.update(
msg.msg,
model.toolManager,
);
model.messages.splice(model.messages.length - 1, 1, next);
return [model];
}
model.toolManager = nextToolManager;
let nextModel = model;
let thunk: Thunk<Msg> | undefined = wrapThunk(
"tool-manager-msg",
toolManagerThunk,
);
if (msg.msg.type == "tool-msg" && msg.msg.msg.msg.type == "finish") {
const toolModel = nextToolManager.toolModels[msg.msg.id];

case "tool-model-update": {
const nextMessages: Model["messages"] = [];
for (const message of model.messages) {
const [next] = updateMessage(
{
type: "tool-model-update",
toolModel: msg.toolModel,
},
message,
);
const response = msg.msg.msg.msg.result;
[nextModel] = addToolResponse(model, toolModel, response);
if (toolModel.autoRespond) {
let shouldRespond = true;
for (const tool of Object.values(model.toolManager.toolModels)) {
if (tool.state.state != "done") {
shouldRespond = false;
break;
}
}

nextMessages.push(next);
if (shouldRespond) {
thunk = parallelThunks<Msg>(thunk, sendMessage(model));
}
}
}

return [
{
...model,
messages: nextMessages,
},
];
return [nextModel, thunk];
}

case "clear": {
return [{ messages: [] }];
return [initModel()];
}
}
};

function sendMessage(model: Model): Thunk<Msg> {
return async function (dispatch: Dispatch<Msg>) {
const messages = getMessages(model);

const toolRequests = await getClient().sendMessage(messages, (text) => {
context.logger.trace(`stream received text ${text}`);
dispatch({
type: "stream-response",
text,
});
});

if (toolRequests.length) {
for (const request of toolRequests) {
dispatch({
type: "init-tool-use",
request,
});
}
}
};
}

function addToolResponse(
model: Model,
toolModel: ToolModel,
response: ToolResultBlockParam,
): [Model] {
let lastMessage = model.messages[model.messages.length - 1];
if (lastMessage?.role !== "user") {
lastMessage = {
role: "user",
parts: [],
};
model.messages.push(lastMessage);
}

const [next] = updateMessage(
{
type: "add-tool-response",
requestId: toolModel.request.id,
response,
},
lastMessage,
);
model.messages.splice(model.messages.length - 1, 1, next);
return [model];
}

export const view: View<{ model: Model; dispatch: Dispatch<Msg> }> = ({
model,
dispatch,
Expand All @@ -179,6 +256,7 @@ export const view: View<{ model: Model; dispatch: Dispatch<Msg> }> = ({
(m, idx) =>
d`${messageView({
model: m,
toolManager: model.toolManager,
dispatch: (msg) => {
dispatch({ type: "message-msg", msg, idx });
},
Expand All @@ -191,6 +269,6 @@ export const view: View<{ model: Model; dispatch: Dispatch<Msg> }> = ({
export function getMessages(model: Model): Anthropic.MessageParam[] {
return model.messages.map((msg) => ({
role: msg.role,
content: msg.parts.map(toMessageParam),
content: msg.parts.map((p) => toMessageParam(p, model.toolManager)),
}));
}
36 changes: 16 additions & 20 deletions rplugin/node/magenta/src/chat/message.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs";
import { Model as Part, view as partView } from "./part.ts";
import { ToolModel } from "../tools/toolManager.ts";
import * as ToolManager from "../tools/toolManager.ts";
import { Role } from "./chat.ts";
import { Dispatch, Update } from "../tea/tea.ts";
import { assertUnreachable } from "../utils/assertUnreachable.ts";
Expand All @@ -18,16 +18,16 @@ export type Msg =
}
| {
type: "add-tool-use";
toolModel: ToolModel;
requestId: ToolManager.ToolRequestId;
}
| {
type: "add-tool-response";
toolModel: ToolModel;
requestId: ToolManager.ToolRequestId;
response: ToolResultBlockParam;
}
| {
type: "tool-model-update";
toolModel: ToolModel;
type: "tool-manager-msg";
msg: ToolManager.Msg;
};

export const update: Update<Msg, Model> = (msg, model) => {
Expand All @@ -48,27 +48,20 @@ export const update: Update<Msg, Model> = (msg, model) => {
case "add-tool-use":
model.parts.push({
type: "tool-request",
toolModel: msg.toolModel,
requestId: msg.requestId,
});
break;

case "add-tool-response":
model.parts.push({
type: "tool-response",
toolModel: msg.toolModel,
requestId: msg.requestId,
response: msg.response,
});
break;

case "tool-model-update": {
for (const part of model.parts) {
if (
(part.type == "tool-request" || part.type == "tool-response") &&
part.toolModel.request.id == msg.toolModel.request.id
) {
part.toolModel = msg.toolModel;
}
}
case "tool-manager-msg": {
// do nothing. This will be handled by the tool manager
return [model];
}

Expand All @@ -78,9 +71,12 @@ export const update: Update<Msg, Model> = (msg, model) => {
return [model];
};

export const view: View<{ model: Model; dispatch: Dispatch<Msg> }> = ({
model,
}) =>
export const view: View<{
model: Model;
toolManager: ToolManager.Model;
dispatch: Dispatch<Msg>;
}> = ({ model, toolManager, dispatch }) =>
d`### ${model.role}:\n${model.parts.map(
(part) => d`${partView({ model: part })}\n`,
(part) =>
d`${partView({ model: part, toolManager, dispatch: (msg) => dispatch({ type: "tool-manager-msg", msg }) })}\n`,
)}`;
Loading

0 comments on commit 5b3a8da

Please sign in to comment.