diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index b7b6289d4..f656ea222 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -37,7 +37,6 @@ import { ModelMiddleware, ModelReference, Part, - ToolDefinition, } from './model.js'; import { ExecutablePrompt } from './prompt.js'; import { resolveTools, ToolArgument, toToolDefinition } from './tool.js'; @@ -135,20 +134,25 @@ async function resolveModel( options: GenerateOptions ): Promise { let model = options.model; + let out: ResolvedModel; + let modelId: string; + if (!model) { - throw new Error('Model is required.'); + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: 'Must supply a `model` to `generate()` calls.', + }); } if (typeof model === 'string') { - return { - modelAction: (await registry.lookupAction( - `/model/${model}` - )) as ModelAction, - }; + modelId = model; + out = { modelAction: await registry.lookupAction(`/model/${model}`) }; } else if (model.hasOwnProperty('__action')) { - return { modelAction: model as ModelAction }; + modelId = (model as ModelAction).__action.name; + out = { modelAction: model as ModelAction }; } else { const ref = model as ModelReference; - return { + modelId = ref.name; + out = { modelAction: (await registry.lookupAction( `/model/${ref.name}` )) as ModelAction, @@ -158,6 +162,15 @@ async function resolveModel( version: ref.version, }; } + + if (!out.modelAction) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Model ${modelId} not found`, + }); + } + + return out; } export class GenerationResponseError extends GenkitError { @@ -180,6 +193,59 @@ export class GenerationResponseError extends GenkitError { } } +async function toolsToActionRefs( + registry: Registry, + toolOpt?: ToolArgument[] +): Promise { + if (!toolOpt) return; + + let tools: string[] = []; + + for (const t of toolOpt) { + if (typeof t === 'string') { + tools.push(await resolveFullToolName(registry, t)); + } else if ((t as Action).__action) { + tools.push( + `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}` + ); + } else if (typeof (t as ExecutablePrompt).asTool === 'function') { + const promptToolAction = (t as ExecutablePrompt).asTool(); + tools.push(`/prompt/${promptToolAction.__action.name}`); + } else if (t.name) { + tools.push(await resolveFullToolName(registry, t.name)); + } else { + throw new Error(`Unable to determine type of tool: ${JSON.stringify(t)}`); + } + } + return tools; +} + +function messagesFromOptions(options: GenerateOptions): MessageData[] { + const messages: MessageData[] = []; + if (options.system) { + messages.push({ + role: 'system', + content: Message.parseContent(options.system), + }); + } + if (options.messages) { + messages.push(...options.messages); + } + if (options.prompt) { + messages.push({ + role: 'user', + content: Message.parseContent(options.prompt), + }); + } + if (messages.length === 0) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: 'at least one message is required in generate request', + }); + } + return messages; +} + /** A GenerationBlockedError is thrown when a generation is blocked. */ export class GenerationBlockedError extends GenerationResponseError {} @@ -206,66 +272,13 @@ export async function generate< const resolvedOptions: GenerateOptions = await Promise.resolve(options); const resolvedModel = await resolveModel(registry, resolvedOptions); - const model = resolvedModel.modelAction; - if (!model) { - let modelId: string; - if (typeof resolvedOptions.model === 'string') { - modelId = resolvedOptions.model; - } else if ((resolvedOptions.model as ModelAction)?.__action?.name) { - modelId = (resolvedOptions.model as ModelAction).__action.name; - } else { - modelId = (resolvedOptions.model as ModelReference).name; - } - throw new Error(`Model ${modelId} not found`); - } - - // convert tools to action refs (strings). - let tools: (string | ToolDefinition)[] | undefined; - if (resolvedOptions.tools) { - tools = []; - for (const t of resolvedOptions.tools) { - if (typeof t === 'string') { - tools.push(await resolveFullToolName(registry, t)); - } else if ((t as Action).__action) { - tools.push( - `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}` - ); - } else if (typeof (t as ExecutablePrompt).asTool === 'function') { - const promptToolAction = (t as ExecutablePrompt).asTool(); - tools.push(`/prompt/${promptToolAction.__action.name}`); - } else if (t.name) { - tools.push(await resolveFullToolName(registry, t.name)); - } else { - throw new Error( - `Unable to determine type of of tool: ${JSON.stringify(t)}` - ); - } - } - } - const messages: MessageData[] = []; - if (resolvedOptions.system) { - messages.push({ - role: 'system', - content: Message.parseContent(resolvedOptions.system), - }); - } - if (resolvedOptions.messages) { - messages.push(...resolvedOptions.messages); - } - if (resolvedOptions.prompt) { - messages.push({ - role: 'user', - content: Message.parseContent(resolvedOptions.prompt), - }); - } + const tools = await toolsToActionRefs(registry, resolvedOptions.tools); - if (messages.length === 0) { - throw new Error('at least one message is required in generate request'); - } + const messages: MessageData[] = messagesFromOptions(resolvedOptions); const params: z.infer = { - model: model.__action.name, + model: resolvedModel.modelAction.__action.name, docs: resolvedOptions.docs, messages, tools,