Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factors out some standalone functions to simplify generate() #1173

Merged
merged 4 commits into from
Nov 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 78 additions & 65 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import {
ModelMiddleware,
ModelReference,
Part,
ToolDefinition,
} from './model.js';
import { ExecutablePrompt } from './prompt.js';
import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';
Expand Down Expand Up @@ -135,20 +134,25 @@ async function resolveModel(
options: GenerateOptions
): Promise<ResolvedModel> {
let model = options.model;
let out: ResolvedModel;
let modelId: string;

if (!model) {
throw new Error('Model is required.');
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'Must supply a `model` to `generate()` calls.',
});
}
if (typeof model === 'string') {
return {
modelAction: (await registry.lookupAction(
`/model/${model}`
)) as ModelAction,
};
modelId = model;
out = { modelAction: await registry.lookupAction(`/model/${model}`) };
} else if (model.hasOwnProperty('__action')) {
return { modelAction: model as ModelAction };
modelId = (model as ModelAction).__action.name;
out = { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
return {
modelId = ref.name;
out = {
modelAction: (await registry.lookupAction(
`/model/${ref.name}`
)) as ModelAction,
Expand All @@ -158,6 +162,15 @@ async function resolveModel(
version: ref.version,
};
}

if (!out.modelAction) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Model ${modelId} not found`,
});
}

return out;
}

export class GenerationResponseError extends GenkitError {
Expand All @@ -180,6 +193,59 @@ export class GenerationResponseError extends GenkitError {
}
}

async function toolsToActionRefs(
registry: Registry,
toolOpt?: ToolArgument[]
): Promise<string[] | undefined> {
if (!toolOpt) return;

let tools: string[] = [];

for (const t of toolOpt) {
if (typeof t === 'string') {
tools.push(await resolveFullToolName(registry, t));
} else if ((t as Action).__action) {
tools.push(
`/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`
);
} else if (typeof (t as ExecutablePrompt).asTool === 'function') {
const promptToolAction = (t as ExecutablePrompt).asTool();
tools.push(`/prompt/${promptToolAction.__action.name}`);
} else if (t.name) {
tools.push(await resolveFullToolName(registry, t.name));
} else {
throw new Error(`Unable to determine type of tool: ${JSON.stringify(t)}`);
}
}
return tools;
}

function messagesFromOptions(options: GenerateOptions): MessageData[] {
const messages: MessageData[] = [];
if (options.system) {
messages.push({
role: 'system',
content: Message.parseContent(options.system),
});
}
if (options.messages) {
messages.push(...options.messages);
}
if (options.prompt) {
messages.push({
role: 'user',
content: Message.parseContent(options.prompt),
});
}
if (messages.length === 0) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'at least one message is required in generate request',
});
}
return messages;
}

/** A GenerationBlockedError is thrown when a generation is blocked. */
export class GenerationBlockedError extends GenerationResponseError {}

Expand All @@ -206,66 +272,13 @@ export async function generate<
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const resolvedModel = await resolveModel(registry, resolvedOptions);
const model = resolvedModel.modelAction;
if (!model) {
let modelId: string;
if (typeof resolvedOptions.model === 'string') {
modelId = resolvedOptions.model;
} else if ((resolvedOptions.model as ModelAction)?.__action?.name) {
modelId = (resolvedOptions.model as ModelAction).__action.name;
} else {
modelId = (resolvedOptions.model as ModelReference<any>).name;
}
throw new Error(`Model ${modelId} not found`);
}

// convert tools to action refs (strings).
let tools: (string | ToolDefinition)[] | undefined;
if (resolvedOptions.tools) {
tools = [];
for (const t of resolvedOptions.tools) {
if (typeof t === 'string') {
tools.push(await resolveFullToolName(registry, t));
} else if ((t as Action).__action) {
tools.push(
`/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`
);
} else if (typeof (t as ExecutablePrompt).asTool === 'function') {
const promptToolAction = (t as ExecutablePrompt).asTool();
tools.push(`/prompt/${promptToolAction.__action.name}`);
} else if (t.name) {
tools.push(await resolveFullToolName(registry, t.name));
} else {
throw new Error(
`Unable to determine type of of tool: ${JSON.stringify(t)}`
);
}
}
}

const messages: MessageData[] = [];
if (resolvedOptions.system) {
messages.push({
role: 'system',
content: Message.parseContent(resolvedOptions.system),
});
}
if (resolvedOptions.messages) {
messages.push(...resolvedOptions.messages);
}
if (resolvedOptions.prompt) {
messages.push({
role: 'user',
content: Message.parseContent(resolvedOptions.prompt),
});
}
const tools = await toolsToActionRefs(registry, resolvedOptions.tools);

if (messages.length === 0) {
throw new Error('at least one message is required in generate request');
}
const messages: MessageData[] = messagesFromOptions(resolvedOptions);

const params: z.infer<typeof GenerateUtilParamSchema> = {
model: model.__action.name,
model: resolvedModel.modelAction.__action.name,
docs: resolvedOptions.docs,
messages,
tools,
Expand Down
Loading