Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anthropic Tool Support #1594

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0c60707
support anthropic PDF beta
evalstate Nov 18, 2024
a0e1b5f
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 18, 2024
84aac49
upstream merge, remove commented out console log line
evalstate Nov 18, 2024
67d7295
Merge branch 'main' into feature/anthropic-pdf-beta
evalstate Nov 18, 2024
92bf923
Fixing type errors.
evalstate Nov 18, 2024
16637d2
Merge branch 'main' into feature/anthropic-pdf-beta
nsarrazin Nov 18, 2024
3e34b63
changed document processor to async (matching image processor)
evalstate Nov 18, 2024
4c67d1c
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 22, 2024
36a1cc3
use the beta api types rather than custom extension
evalstate Nov 22, 2024
3cbc5de
Merge remote-tracking branch 'upstream/main' into feature/anthropic-p…
evalstate Nov 25, 2024
786b576
rudimentary tool testing
evalstate Nov 25, 2024
e66ba8d
Merge branch 'main' of https://github.com/huggingface/chat-ui into fe…
evalstate Nov 25, 2024
da22402
interim commit (tool re-passing, file handling)
evalstate Nov 26, 2024
b69d18e
Merge branch 'feature/anthropic-pdf-beta' into feature/anthropic-tool…
evalstate Nov 26, 2024
506ecff
remove merge error
evalstate Nov 26, 2024
3c3d282
Merge branch 'feature/anthropic-pdf-beta' of https://github.com/barre…
evalstate Nov 26, 2024
07764f5
Merge branch 'feature/anthropic-pdf-beta' into feature/anthropic-tool…
evalstate Nov 26, 2024
b233de7
tidy up, isolate beta classes to utils
evalstate Nov 26, 2024
ee6107a
anthropic tool calling support.
evalstate Nov 26, 2024
a8bac58
improve handling of directlyAnswer tool
evalstate Nov 26, 2024
87b57f3
fix streaming
evalstate Nov 26, 2024
0c9abdf
slight tidy up to tools flow handling
evalstate Nov 27, 2024
d3ceb35
Merge remote-tracking branch 'upstream/main' into feature/anthropic-t…
evalstate Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
"zod": "^3.22.3"
},
"optionalDependencies": {
"@anthropic-ai/sdk": "^0.25.0",
"@anthropic-ai/sdk": "^0.32.1",
"@anthropic-ai/vertex-sdk": "^0.4.1",
"@aws-sdk/client-bedrock-runtime": "^3.631.0",
"@google-cloud/vertexai": "^1.1.0",
Expand Down
141 changes: 126 additions & 15 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@ import type { Endpoint } from "../endpoints";
import { env } from "$env/dynamic/private";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator } from "../images";
import { endpointMessagesToAnthropicMessages } from "./utils";
import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils";
import { createDocumentProcessorOptionsValidator } from "../document";
import type {
Tool,
ToolCall,
ToolInput,
ToolInputFile,
ToolInputFixed,
ToolInputOptional,
} from "$lib/types/Tool";
import type Anthropic from "@anthropic-ai/sdk";
import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
import directlyAnswer from "$lib/server/tools/directlyAnswer";

export const endpointAnthropicParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand All @@ -23,6 +35,10 @@ export const endpointAnthropicParametersSchema = z.object({
maxWidth: 4096,
maxHeight: 4096,
}),
document: createDocumentProcessorOptionsValidator({
supportedMimeTypes: ["application/pdf"],
maxSizeInMB: 32,
}),
})
.default({}),
});
Expand All @@ -46,7 +62,14 @@ export async function endpointAnthropic(
defaultQuery,
});

return async ({ messages, preprompt, generateSettings }) => {
return async ({
messages,
preprompt,
generateSettings,
conversationId,
tools = [],
toolResults = [],
}) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
Expand All @@ -59,7 +82,13 @@ export async function endpointAnthropic(
return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: await endpointMessagesToAnthropicMessages(messages, multimodal),
tools: createAnthropicTools(tools),
tool_choice:
tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined,
messages: addToolResults(
await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId),
toolResults
) as MessageParam[],
max_tokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
Expand All @@ -70,21 +99,40 @@ export async function endpointAnthropic(
while (true) {
const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);

// Stream end
if (result === undefined) {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
if ("tool_use" === stream.receivedMessages[0].stop_reason) {
// this should really create a new "Assistant" message with the tool id in it.
const toolCalls: ToolCall[] = stream.receivedMessages[0].content
.filter(
(block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } =>
block.type === "tool_use"
)
.map((block) => ({
name: block.name,
parameters: block.input as Record<string, string | number | boolean>,
id: block.id,
}));

yield {
token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls },
generated_text: null,
details: null,
};
} else {
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: await stream.finalText(),
details: null,
} satisfies TextGenerationStreamOutput;
}

return;
}

// Text delta
yield {
token: {
Expand All @@ -100,3 +148,66 @@ export async function endpointAnthropic(
})();
};
}

/**
 * Translates chat-ui tool definitions into the tool descriptors sent to the
 * Anthropic Messages API.
 *
 * The tool whose name matches `directlyAnswer.name` is left out of the result.
 * Every remaining tool becomes an object-typed input schema with one property
 * per tool input (built by `convertToolInputToJSONSchema`).
 *
 * @param tools - Tools available for the current generation.
 * @returns One Anthropic tool descriptor per retained tool, in input order.
 */
function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] {
	const anthropicTools: Anthropic.Messages.Tool[] = [];

	for (const tool of tools) {
		if (tool.name === directlyAnswer.name) {
			continue;
		}

		const properties: Record<string, unknown> = {};
		for (const input of tool.inputs) {
			properties[input.name] = convertToolInputToJSONSchema(input);
		}

		const requiredNames: string[] = [];
		for (const input of tool.inputs) {
			if (input.paramType === "required") {
				requiredNames.push(input.name);
			}
		}

		anthropicTools.push({
			name: tool.name,
			description: tool.description,
			input_schema: {
				type: "object",
				properties,
				// `required` is left undefined rather than sent as an empty list.
				required: requiredNames.length > 0 ? requiredNames : undefined,
			},
		});
	}

	return anthropicTools;
}

/**
 * Builds the JSON-schema fragment describing a single tool input.
 *
 * Keys are written in a fixed order: `description` (only when the input
 * carries that property; a falsy value becomes `""`), then `default` for
 * optional inputs or `const` for fixed inputs, then the type information.
 * File inputs are described as binary-format strings together with their
 * `mimeTypes`; the simple types map str/int/float/bool to
 * string/integer/number/boolean. Any other type leaves `type` unset.
 *
 * @param input - Tool input definition to convert.
 * @returns The schema fragment for this input.
 */
function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> {
	const schema: Record<string, unknown> = {};

	if ("description" in input) {
		schema["description"] = input.description || "";
	}

	if (input.paramType === "optional") {
		schema["default"] = (input as ToolInputOptional).default;
	} else if (input.paramType === "fixed") {
		schema["const"] = (input as ToolInputFixed).value;
	}

	if (input.type === "file") {
		schema["type"] = "string";
		schema["format"] = "binary";
		schema["mimeTypes"] = (input as ToolInputFile).mimeTypes;
		return schema;
	}

	const jsonTypeBySimpleType = new Map<string, string>([
		["str", "string"],
		["int", "integer"],
		["float", "number"],
		["bool", "boolean"],
	]);
	const jsonType = jsonTypeBySimpleType.get(input.type);
	if (jsonType !== undefined) {
		schema["type"] = jsonType;
	}

	return schema;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator } from "../images";
import { endpointMessagesToAnthropicMessages } from "./utils";
import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";

export const endpointAnthropicVertexParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand Down Expand Up @@ -56,7 +57,10 @@ export async function endpointAnthropicVertex(
return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: await endpointMessagesToAnthropicMessages(messages, multimodal),
messages: (await endpointMessagesToAnthropicMessages(
messages,
multimodal
)) as MessageParam[],
max_tokens: model.parameters?.max_new_tokens,
temperature: model.parameters?.temperature,
top_p: model.parameters?.top_p,
Expand Down
94 changes: 85 additions & 9 deletions src/lib/server/endpoints/anthropic/utils.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import { makeImageProcessor, type ImageProcessorOptions } from "../images";
import { makeDocumentProcessor, type FileProcessorOptions } from "../document";
import type { EndpointMessage } from "../endpoints";
import type { MessageFile } from "$lib/types/Message";
import type { ImageBlockParam, MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
import type {
BetaImageBlockParam,
BetaMessageParam,
BetaBase64PDFBlock,
} from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs";
import type { ToolResult } from "$lib/types/Tool";
import { downloadFile } from "$lib/server/files/downloadFile";
import type { ObjectId } from "mongodb";

export async function fileToImageBlock(
file: MessageFile,
opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
): Promise<ImageBlockParam> {
): Promise<BetaImageBlockParam> {
const processor = makeImageProcessor(opts);

const { image, mime } = await processor(file);

return {
Expand All @@ -20,25 +29,92 @@ export async function fileToImageBlock(
};
}

type NonSystemMessage = EndpointMessage & { from: "user" | "assistant" };
/**
 * Converts an uploaded PDF into an Anthropic base64 document content block.
 *
 * The file is first passed through the processor produced by
 * `makeDocumentProcessor(opts)`; the bytes and MIME type it returns are what
 * end up in the block.
 *
 * @param file - Message attachment to convert.
 * @param opts - Document processor options for PDF files.
 * @returns A `document` block whose source is the base64-encoded file.
 */
export async function fileToDocumentBlock(
	file: MessageFile,
	opts: FileProcessorOptions<"application/pdf">
): Promise<BetaBase64PDFBlock> {
	const processed = await makeDocumentProcessor(opts)(file);
	const base64Data = processed.file.toString("base64");

	return {
		type: "document",
		source: {
			type: "base64",
			media_type: processed.mime,
			data: base64Data,
		},
	};
}

type NonSystemMessage = EndpointMessage & { from: "user" | "assistant" };
export async function endpointMessagesToAnthropicMessages(
messages: EndpointMessage[],
multimodal: { image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp"> }
): Promise<MessageParam[]> {
multimodal: {
image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">;
document?: FileProcessorOptions<"application/pdf">;
},
conversationId?: ObjectId | undefined
): Promise<BetaMessageParam[]> {
return await Promise.all(
messages
.filter((message): message is NonSystemMessage => message.from !== "system")
.map<Promise<MessageParam>>(async (message) => {
.map<Promise<BetaMessageParam>>(async (message) => {
return {
role: message.from,
content: [
...(await Promise.all(
(message.files ?? []).map((file) => fileToImageBlock(file, multimodal.image))
)),
...(message.from === "user"
? await Promise.all(
(message.files ?? []).map(async (file) => {
if (file.type === "hash" && conversationId) {
file = await downloadFile(file.value, conversationId);
}

if (file.mime.startsWith("image/")) {
return fileToImageBlock(file, multimodal.image);
} else if (file.mime === "application/pdf" && multimodal.document) {
return fileToDocumentBlock(file, multimodal.document);
} else {
throw new Error(`Unsupported file type: ${file.mime}`);
}
})
)
: []),
{ type: "text", text: message.content },
],
};
})
);
}

/**
 * Appends the outcome of executed tool calls to an Anthropic conversation.
 *
 * When there are tool results, two messages are added after the existing ones:
 * an assistant message holding one `tool_use` block per result (rebuilt from
 * the recorded call) and a user message holding the matching `tool_result`
 * blocks. The ids linking each pair are synthesized here as
 * `tool_<index>_<uuid>`, with a single UUID shared by the whole batch.
 *
 * @param messages - Conversation converted to Anthropic message params.
 * @param toolResults - Results of the tool calls to report back to the model.
 * @returns `messages` itself (same array) when `toolResults` is empty;
 *   otherwise a new array with the two extra messages appended.
 */
export function addToolResults(
	messages: BetaMessageParam[],
	toolResults: ToolResult[]
): BetaMessageParam[] {
	if (toolResults.length === 0) {
		return messages;
	}

	// Generated only once we know it is needed, so the common no-tool path does
	// no extra work and does not depend on the global `crypto` being available.
	const id = crypto.randomUUID();

	return [
		...messages,
		{
			role: "assistant",
			content: toolResults.map((result, index) => ({
				type: "tool_use",
				id: `tool_${index}_${id}`,
				name: result.call.name,
				input: result.call.parameters,
			})),
		},
		{
			role: "user",
			content: toolResults.map((result, index) => ({
				type: "tool_result",
				// Must match the id of the tool_use block at the same index above.
				tool_use_id: `tool_${index}_${id}`,
				is_error: result.status === "error",
				// Errors report their message; other results report their outputs
				// when present, or an empty string otherwise.
				content: JSON.stringify(
					result.status === "error" ? result.message : "outputs" in result ? result.outputs : ""
				),
			})),
		},
	];
}
13 changes: 9 additions & 4 deletions src/lib/server/endpoints/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,28 @@ export type DocumentProcessor<TMimeType extends string = string> = (file: Messag
mime: TMimeType;
};

/**
 * Signature of an asynchronous document processor: it receives a message file
 * and resolves to the file's bytes together with its MIME type.
 *
 * @typeParam TMimeType - MIME type string(s) the processor may report.
 */
export type AsyncDocumentProcessor<TMimeType extends string = string> = (
	file: MessageFile
) => Promise<{ file: Buffer; mime: TMimeType }>;

export function makeDocumentProcessor<TMimeType extends string = string>(
options: FileProcessorOptions<TMimeType>
): DocumentProcessor<TMimeType> {
return (file) => {
): AsyncDocumentProcessor<TMimeType> {
return async (file) => {
const { supportedMimeTypes, maxSizeInMB } = options;
const { mime, value } = file;

const buffer = Buffer.from(value, "base64");

const tooLargeInBytes = buffer.byteLength > maxSizeInMB * 1000 * 1000;

if (tooLargeInBytes) {
throw Error("Document is too large");
}

const outputMime = validateMimeType(supportedMimeTypes, mime);

return { file: buffer, mime: outputMime };
};
}
Expand Down
Loading