Skip to content

Commit

Permalink
Factors out some standalone functions to simplify generate() (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Nov 5, 2024
1 parent 5c03e9f commit 1c77685
Showing 1 changed file with 78 additions and 65 deletions.
143 changes: 78 additions & 65 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import {
ModelMiddleware,
ModelReference,
Part,
ToolDefinition,
} from './model.js';
import { ExecutablePrompt } from './prompt.js';
import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';
Expand Down Expand Up @@ -135,20 +134,25 @@ async function resolveModel(
options: GenerateOptions
): Promise<ResolvedModel> {
let model = options.model;
let out: ResolvedModel;
let modelId: string;

if (!model) {
throw new Error('Model is required.');
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'Must supply a `model` to `generate()` calls.',
});
}
if (typeof model === 'string') {
return {
modelAction: (await registry.lookupAction(
`/model/${model}`
)) as ModelAction,
};
modelId = model;
out = { modelAction: await registry.lookupAction(`/model/${model}`) };
} else if (model.hasOwnProperty('__action')) {
return { modelAction: model as ModelAction };
modelId = (model as ModelAction).__action.name;
out = { modelAction: model as ModelAction };
} else {
const ref = model as ModelReference<any>;
return {
modelId = ref.name;
out = {
modelAction: (await registry.lookupAction(
`/model/${ref.name}`
)) as ModelAction,
Expand All @@ -158,6 +162,15 @@ async function resolveModel(
version: ref.version,
};
}

if (!out.modelAction) {
throw new GenkitError({
status: 'NOT_FOUND',
message: `Model ${modelId} not found`,
});
}

return out;
}

export class GenerationResponseError extends GenkitError {
Expand All @@ -180,6 +193,59 @@ export class GenerationResponseError extends GenkitError {
}
}

async function toolsToActionRefs(
registry: Registry,
toolOpt?: ToolArgument[]
): Promise<string[] | undefined> {
if (!toolOpt) return;

let tools: string[] = [];

for (const t of toolOpt) {
if (typeof t === 'string') {
tools.push(await resolveFullToolName(registry, t));
} else if ((t as Action).__action) {
tools.push(
`/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`
);
} else if (typeof (t as ExecutablePrompt).asTool === 'function') {
const promptToolAction = (t as ExecutablePrompt).asTool();
tools.push(`/prompt/${promptToolAction.__action.name}`);
} else if (t.name) {
tools.push(await resolveFullToolName(registry, t.name));
} else {
throw new Error(`Unable to determine type of tool: ${JSON.stringify(t)}`);
}
}
return tools;
}

function messagesFromOptions(options: GenerateOptions): MessageData[] {
const messages: MessageData[] = [];
if (options.system) {
messages.push({
role: 'system',
content: Message.parseContent(options.system),
});
}
if (options.messages) {
messages.push(...options.messages);
}
if (options.prompt) {
messages.push({
role: 'user',
content: Message.parseContent(options.prompt),
});
}
if (messages.length === 0) {
throw new GenkitError({
status: 'INVALID_ARGUMENT',
message: 'at least one message is required in generate request',
});
}
return messages;
}

/** A GenerationBlockedError is thrown when a generation is blocked. */
export class GenerationBlockedError extends GenerationResponseError {}

Expand All @@ -206,66 +272,13 @@ export async function generate<
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const resolvedModel = await resolveModel(registry, resolvedOptions);
const model = resolvedModel.modelAction;
if (!model) {
let modelId: string;
if (typeof resolvedOptions.model === 'string') {
modelId = resolvedOptions.model;
} else if ((resolvedOptions.model as ModelAction)?.__action?.name) {
modelId = (resolvedOptions.model as ModelAction).__action.name;
} else {
modelId = (resolvedOptions.model as ModelReference<any>).name;
}
throw new Error(`Model ${modelId} not found`);
}

// convert tools to action refs (strings).
let tools: (string | ToolDefinition)[] | undefined;
if (resolvedOptions.tools) {
tools = [];
for (const t of resolvedOptions.tools) {
if (typeof t === 'string') {
tools.push(await resolveFullToolName(registry, t));
} else if ((t as Action).__action) {
tools.push(
`/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`
);
} else if (typeof (t as ExecutablePrompt).asTool === 'function') {
const promptToolAction = (t as ExecutablePrompt).asTool();
tools.push(`/prompt/${promptToolAction.__action.name}`);
} else if (t.name) {
tools.push(await resolveFullToolName(registry, t.name));
} else {
throw new Error(
`Unable to determine type of of tool: ${JSON.stringify(t)}`
);
}
}
}

const messages: MessageData[] = [];
if (resolvedOptions.system) {
messages.push({
role: 'system',
content: Message.parseContent(resolvedOptions.system),
});
}
if (resolvedOptions.messages) {
messages.push(...resolvedOptions.messages);
}
if (resolvedOptions.prompt) {
messages.push({
role: 'user',
content: Message.parseContent(resolvedOptions.prompt),
});
}
const tools = await toolsToActionRefs(registry, resolvedOptions.tools);

if (messages.length === 0) {
throw new Error('at least one message is required in generate request');
}
const messages: MessageData[] = messagesFromOptions(resolvedOptions);

const params: z.infer<typeof GenerateUtilParamSchema> = {
model: model.__action.name,
model: resolvedModel.modelAction.__action.name,
docs: resolvedOptions.docs,
messages,
tools,
Expand Down

0 comments on commit 1c77685

Please sign in to comment.