From 786024ca86c2937cc1ab26dee7ca0a72af92252f Mon Sep 17 00:00:00 2001 From: Denis Lantsman Date: Wed, 18 Dec 2024 14:46:41 -0800 Subject: [PATCH] add replace tool --- rplugin/node/magenta/src/tools/diff.ts | 75 ++++++- rplugin/node/magenta/src/tools/insert.ts | 27 ++- rplugin/node/magenta/src/tools/replace.ts | 209 ++++++++++++++++++ rplugin/node/magenta/src/tools/toolManager.ts | 77 ++++++- 4 files changed, 365 insertions(+), 23 deletions(-) create mode 100644 rplugin/node/magenta/src/tools/replace.ts diff --git a/rplugin/node/magenta/src/tools/diff.ts b/rplugin/node/magenta/src/tools/diff.ts index 84ba4c3..c68f1c7 100644 --- a/rplugin/node/magenta/src/tools/diff.ts +++ b/rplugin/node/magenta/src/tools/diff.ts @@ -1,16 +1,34 @@ import { context } from "../context.ts"; import { Buffer } from "neovim"; import { WIDTH } from "../sidebar.ts"; +import { assertUnreachable } from "../utils/assertUnreachable.ts"; +import { Dispatch } from "../tea/tea.ts"; -type Edit = { - type: "insert-after"; - insertAfter: string; - content: string; +type Edit = + | { + type: "insert-after"; + insertAfter: string; + content: string; + } + | { + type: "replace"; + start: string; + end: string; + content: string; + }; + +type Msg = { + type: "error"; + error: string; }; /** Helper to bring up an editing interface for the given file path. */ -export async function displayDiffs(filePath: string, edits: Edit[]) { +export async function displayDiffs( + filePath: string, + edits: Edit[], + dispatch: Dispatch, +) { const { nvim } = context; // first, check to see if any windows *other than* the magenta plugin windows are open, and close them. @@ -48,12 +66,47 @@ export async function displayDiffs(filePath: string, edits: Edit[]) { let content: string = lines.join("\n"); for (const edit of edits) { - const insertLocation = - content.indexOf(edit.insertAfter) + edit.insertAfter.length; - content = - content.slice(0, insertLocation) + - edit.content + - content.slice(insertLocation); + switch (edit.type) { + case "insert-after": { + const insertLocation = + content.indexOf(edit.insertAfter) + edit.insertAfter.length; + content = + content.slice(0, insertLocation) + + edit.content + + content.slice(insertLocation); + break; + } + + case "replace": { + const insertStart = content.indexOf(edit.start); + const insertEnd = content.indexOf(edit.end); + + if (insertStart == -1) { + dispatch({ + type: "error", + error: `Unable to find start location of string ${edit.start} in file ${filePath}`, + }); + continue; + } + + if (insertEnd == -1) { + dispatch({ + type: "error", + error: `Unable to find end location of string ${edit.end} in file ${filePath}`, + }); + continue; + } + content = + content.slice(0, insertStart) + + edit.content + + content.slice(insertEnd + edit.end.length); + + break; + } + + default: + assertUnreachable(edit); + } } const scratchBuffer = (await nvim.createBuffer(false, true)) as Buffer; diff --git a/rplugin/node/magenta/src/tools/insert.ts b/rplugin/node/magenta/src/tools/insert.ts index fd79c0e..5f04ec3 100644 --- a/rplugin/node/magenta/src/tools/insert.ts +++ b/rplugin/node/magenta/src/tools/insert.ts @@ -76,13 +76,26 @@ export function insertThunk(model: Model) { const request = model.request; return async (dispatch: Dispatch) => { try { - await displayDiffs(request.input.filePath, [ - { - type: "insert-after", - insertAfter: request.input.insertAfter, - content: request.input.content, - }, - ]); + await displayDiffs( + request.input.filePath, + [ + { + type: "insert-after", + insertAfter: request.input.insertAfter, + content: request.input.content, + }, + ], + (msg) => + dispatch({ + type: "finish", + result: { + type: "tool_result", + tool_use_id: model.request.id, + content: msg.error, + is_error: true, + }, + }), + ); } catch (error) { dispatch({ type: "finish", diff --git a/rplugin/node/magenta/src/tools/replace.ts b/rplugin/node/magenta/src/tools/replace.ts new file mode 100644 index 0000000..e19a104 --- /dev/null +++ b/rplugin/node/magenta/src/tools/replace.ts @@ -0,0 +1,209 @@ +import * as Anthropic from "@anthropic-ai/sdk"; +import { assertUnreachable } from "../utils/assertUnreachable.ts"; +import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/index.mjs"; +import { Dispatch, Update } from "../tea/tea.ts"; +import { d, VDOMNode, withBindings } from "../tea/view.ts"; +import { ToolRequestId } from "./toolManager.ts"; +import { displayDiffs } from "./diff.ts"; + +export type Model = { + type: "replace"; + autoRespond: boolean; + request: ReplaceToolRequest; + state: + | { + state: "pending-user-action"; + } + | { + state: "editing-diff"; + } + | { + state: "done"; + result: ToolResultBlockParam; + }; +}; + +export type Msg = + | { + type: "finish"; + result: ToolResultBlockParam; + } + | { + type: "display-diff"; + }; + +export const update: Update = (msg, model) => { + switch (msg.type) { + case "finish": + return [ + { + ...model, + state: { + state: "done", + result: msg.result, + }, + }, + ]; + case "display-diff": + return [ + { + ...model, + state: { + state: "pending-user-action", + }, + }, + insertThunk(model), + ]; + default: + assertUnreachable(msg); + } +}; + +export function initModel(request: ReplaceToolRequest): [Model] { + const model: Model = { + type: "replace", + autoRespond: false, + request, + state: { + state: "pending-user-action", + }, + }; + + return [model]; +} + +export function insertThunk(model: Model) { + const request = model.request; + return async (dispatch: Dispatch) => { + try { + await displayDiffs( + request.input.filePath, + [ + { + type: "replace", + start: request.input.start, + end: request.input.end, + content: request.input.content, + }, + ], + (msg) => + dispatch({ + type: "finish", + result: { + type: "tool_result", + tool_use_id: model.request.id, + content: msg.error, + is_error: true, + }, + }), + ); + } catch (error) { + dispatch({ + type: "finish", + result: { + type: "tool_result", + tool_use_id: request.id, + content: `Error: ${(error as Error).message}`, + is_error: true, + }, + }); + } + }; +} + +export function view({ + model, + dispatch, +}: { + model: Model; + dispatch: Dispatch; +}): VDOMNode { + return d`Insert ${( + model.request.input.content.match(/\n/g) || [] + ).length.toString()} into file ${model.request.input.filePath} +${toolStatusView({ model, dispatch })}`; +} + +function toolStatusView({ + model, + dispatch, +}: { + model: Model; + dispatch: Dispatch; +}): VDOMNode { + switch (model.state.state) { + case "pending-user-action": + return withBindings(d`[review diff]`, { + Enter: () => + dispatch({ + type: "display-diff", + }), + }); + case "editing-diff": + return d`Editing diff`; + case "done": + return d`Done`; + } +} + +export function getToolResult(model: Model): ToolResultBlockParam { + switch (model.state.state) { + case "editing-diff": + return { + type: "tool_result", + tool_use_id: model.request.id, + content: `The user is reviewing the change. Please proceed with your answer or address other parts of the question.`, + }; + case "pending-user-action": + return { + type: "tool_result", + tool_use_id: model.request.id, + content: `Waiting for a user action to finish processing this tool use. Please proceed with your answer or address other parts of the question.`, + }; + case "done": + return model.state.result; + default: + assertUnreachable(model.state); + } +} + +export const spec: Anthropic.Anthropic.Tool = { + name: "insert", + description: "Replace text between two strings in a file.", + input_schema: { + type: "object", + properties: { + filePath: { + type: "string", + description: "Path of the file to modify.", + }, + start: { + type: "string", + description: + "We will replace text starting with this string. This string is included in the text that is replaced. Please provide a minimal string that uniquely identifies a location in the file.", + }, + end: { + type: "string", + description: + "We will replace text until we encounter this string. This string is included in the text that is replaced. Please provide a minimal string that uniquely identifies a location in the file.", + }, + content: { + type: "string", + description: "Content to insert", + }, + }, + required: ["filePath", "start", "end", "content"], + }, +}; + +export type ReplaceToolRequest = { + type: "tool_use"; + id: ToolRequestId; + name: "replace"; + input: { + filePath: string; + start: string; + end: string; + content: string; + }; +}; diff --git a/rplugin/node/magenta/src/tools/toolManager.ts b/rplugin/node/magenta/src/tools/toolManager.ts index 2b69d59..df81958 100644 --- a/rplugin/node/magenta/src/tools/toolManager.ts +++ b/rplugin/node/magenta/src/tools/toolManager.ts @@ -1,22 +1,24 @@ import * as GetFile from "./getFile.ts"; import * as Insert from "./insert.ts"; +import * as Replace from "./replace.ts"; import { Dispatch, Update } from "../tea/tea.ts"; import { assertUnreachable } from "../utils/assertUnreachable.ts"; import { ToolResultBlockParam } from "@anthropic-ai/sdk/resources/messages.mjs"; export type ToolRequest = | GetFile.GetFileToolUseRequest - | Insert.InsertToolUseRequest; + | Insert.InsertToolUseRequest + | Replace.ReplaceToolRequest; -export type ToolModel = GetFile.Model | Insert.Model; +export type ToolModel = GetFile.Model | Insert.Model | Replace.Model; export type ToolRequestId = string & { __toolRequestId: true }; -export const TOOL_SPECS = [GetFile.spec, Insert.spec]; +export const TOOL_SPECS = [GetFile.spec, Insert.spec, Replace.spec]; export type Model = { toolModels: { - [id: ToolRequestId]: GetFile.Model | Insert.Model; + [id: ToolRequestId]: ToolModel; }; }; @@ -26,6 +28,9 @@ export function getToolResult(model: ToolModel): ToolResultBlockParam { return GetFile.getToolResult(model); case "insert": return Insert.getToolResult(model); + case "replace": + return Replace.getToolResult(model); + default: return assertUnreachable(model); } @@ -45,6 +50,17 @@ export function renderTool(model: ToolModel, dispatch: Dispatch) { msg: { type: "insert", msg }, }), }); + case "replace": + return Replace.view({ + model, + dispatch: (msg) => + dispatch({ + type: "tool-msg", + id: model.request.id, + msg: { type: "insert", msg }, + }), + }); + default: assertUnreachable(model); } @@ -53,7 +69,10 @@ export function renderTool(model: ToolModel, dispatch: Dispatch) { export type Msg = | { type: "init-tool-use"; - request: GetFile.GetFileToolUseRequest | Insert.InsertToolUseRequest; + request: + | GetFile.GetFileToolUseRequest + | Insert.InsertToolUseRequest + | Replace.ReplaceToolRequest; } | { type: "tool-msg"; @@ -66,6 +85,10 @@ export type Msg = | { type: "insert"; msg: Insert.Msg; + } + | { + type: "replace"; + msg: Replace.Msg; }; }; @@ -116,6 +139,20 @@ export const update: Update = (msg, model) => { }, ]; } + + case "replace": { + const [insertModel] = Replace.initModel(request); + return [ + { + ...model, + toolModels: { + ...model.toolModels, + [request.id]: insertModel, + }, + }, + ]; + } + default: return assertUnreachable(request); } @@ -188,6 +225,36 @@ export const update: Update = (msg, model) => { ]; } + case "replace": { + const [nextToolModel, thunk] = Replace.update( + msg.msg.msg, + toolModel as Replace.Model, + ); + + return [ + { + ...model, + toolModels: { + ...model.toolModels, + [msg.id]: nextToolModel, + }, + }, + thunk + ? (dispatch) => + thunk((innerMsg) => + dispatch({ + type: "tool-msg", + id: msg.id, + msg: { + type: "replace", + msg: innerMsg, + }, + }), + ) + : undefined, + ]; + } + default: return assertUnreachable(msg.msg); }