Ollama: use the new Chat endpoint. Closes #270
enricoros committed Dec 11, 2023
1 parent d0ea96e commit 11055b1
Showing 3 changed files with 113 additions and 29 deletions.
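
Note for context (not part of the commit): the change replaces calls to Ollama's /api/generate endpoint with the newer /api/chat endpoint, which takes a messages array and returns the generated text under message.content instead of response. A minimal sketch of a direct, non-streaming call, assuming a local Ollama server on the default host and an illustrative model name:

```ts
// Sketch only: calling Ollama's Chat endpoint directly (model name is illustrative).
async function chatOnce(prompt: string): Promise<string> {
  const res = await fetch('http://127.0.0.1:11434/api/chat', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'llama2',
      messages: [{ role: 'user', content: prompt }],
      stream: false, // a single JSON object instead of a newline-delimited stream
    }),
  });
  const json = await res.json();
  // the Chat endpoint returns text under message.content (vs. 'response' on /api/generate)
  return json.message?.content ?? '';
}
```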
45 changes: 35 additions & 10 deletions src/modules/llms/transports/server/ollama/ollama.router.ts
@@ -11,12 +11,15 @@ import { capitalizeFirstLetter } from '~/common/util/textUtils';
import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';
import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas';

import { OLLAMA_BASE_MODELS, OLLAMA_LAST_UPDATE } from './ollama.models';
import { wireOllamaGenerationSchema } from './ollama.wiretypes';
import { OLLAMA_BASE_MODELS, OLLAMA_PREV_UPDATE } from './ollama.models';
import { WireOllamaChatCompletionInput, wireOllamaChunkedOutputSchema } from './ollama.wiretypes';


// Default hosts
const DEFAULT_OLLAMA_HOST = 'http://127.0.0.1:11434';
export const OLLAMA_PATH_CHAT = '/api/chat';
const OLLAMA_PATH_TAGS = '/api/tags';
const OLLAMA_PATH_SHOW = '/api/show';


// Mappers
@@ -34,7 +37,23 @@ export function ollamaAccess(access: OllamaAccessSchema, apiPath: string): { hea

}

export function ollamaChatCompletionPayload(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean) {

export const ollamaChatCompletionPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): WireOllamaChatCompletionInput => ({
model: model.id,
messages: history,
options: {
...(model.temperature && { temperature: model.temperature }),
},
// n: ...
// functions: ...
// function_call: ...
stream,
});


/* Unused: switched to the Chat endpoint (above). The implementation is left here for reference.
https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion
export function ollamaCompletionPayload(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean) {
// if the first message is the system prompt, extract it
let systemPrompt: string | undefined = undefined;
@@ -62,7 +81,7 @@ export function ollamaChatCompletionPayload(model: OpenAIModelSchema, history: O
...(systemPrompt && { system: systemPrompt }),
stream,
};
}
}*/

async function ollamaGET<TOut extends object>(access: OllamaAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
const { headers, url } = ollamaAccess(access, apiPath);
@@ -104,6 +123,7 @@ const listPullableOutputSchema = z.object({
label: z.string(),
tag: z.string(),
description: z.string(),
pulls: z.number(),
isNew: z.boolean(),
})),
});
@@ -122,7 +142,8 @@ export const llmOllamaRouter = createTRPCRouter({
label: capitalizeFirstLetter(model_id),
tag: 'latest',
description: model.description,
isNew: !!model.added && model.added >= OLLAMA_LAST_UPDATE,
pulls: model.pulls,
isNew: !!model.added && model.added >= OLLAMA_PREV_UPDATE,
})),
};
}),
@@ -160,14 +181,15 @@ export const llmOllamaRouter = createTRPCRouter({
throw new Error('Ollama delete issue: ' + deleteOutput);
}),


/* Ollama: List the Models available */
listModels: publicProcedure
.input(accessOnlySchema)
.output(listModelsOutputSchema)
.query(async ({ input }) => {

// get the models
const wireModels = await ollamaGET(input.access, '/api/tags');
const wireModels = await ollamaGET(input.access, OLLAMA_PATH_TAGS);
const wireOllamaListModelsSchema = z.object({
models: z.array(z.object({
name: z.string(),
@@ -180,7 +202,7 @@ export const llmOllamaRouter = createTRPCRouter({

// retrieve info for each of the models (/api/show, post call, in parallel)
const detailedModels = await Promise.all(models.map(async model => {
const wireModelInfo = await ollamaPOST(input.access, { 'name': model.name }, '/api/show');
const wireModelInfo = await ollamaPOST(input.access, { 'name': model.name }, OLLAMA_PATH_SHOW);
const wireOllamaModelInfoSchema = z.object({
license: z.string().optional(),
modelfile: z.string(),
@@ -221,12 +243,15 @@ export const llmOllamaRouter = createTRPCRouter({
.output(openAIChatGenerateOutputSchema)
.mutation(async ({ input: { access, history, model } }) => {

const wireGeneration = await ollamaPOST(access, ollamaChatCompletionPayload(model, history, false), '/api/generate');
const generation = wireOllamaGenerationSchema.parse(wireGeneration);
const wireGeneration = await ollamaPOST(access, ollamaChatCompletionPayload(model, history, false), OLLAMA_PATH_CHAT);
const generation = wireOllamaChunkedOutputSchema.parse(wireGeneration);

if (!generation.message?.content)
throw new Error('Ollama chat generation (non-stream) issue: ' + JSON.stringify(wireGeneration));

return {
role: 'assistant',
content: generation.response,
content: generation.message.content,
finish_reason: generation.done ? 'stop' : null,
};
}),
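
For illustration (not part of the diff), the new ollamaChatCompletionPayload above would produce a request body of roughly this shape for a two-message history; the model id, messages, and temperature are made up:

```ts
// Hypothetical output of ollamaChatCompletionPayload(model, history, false)
const exampleChatBody = {
  model: 'llama2:latest',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Why is the sky blue?' },
  ],
  options: { temperature: 0.5 }, // only present when model.temperature is set
  stream: false,
};
```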
65 changes: 59 additions & 6 deletions src/modules/llms/transports/server/ollama/ollama.wiretypes.ts
@@ -1,16 +1,69 @@
import { z } from 'zod';

export const wireOllamaGenerationSchema = z.object({

/**
* Chat Completion API - Request
* https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion
*/
const wireOllamaChatCompletionInputSchema = z.object({

// required
model: z.string(),
messages: z.array(z.object({
role: z.enum(['assistant', 'system', 'user']),
content: z.string(),
})),

// optional
format: z.enum(['json']).optional(),
options: z.object({
// https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md
// Maximum number of tokens to predict when generating text.
num_predict: z.number().int().optional(),
// Sets the random number seed to use for generation
seed: z.number().int().optional(),
// The temperature of the model
temperature: z.number().positive().optional(),
// Reduces the probability of generating nonsense (Default: 40)
top_k: z.number().positive().optional(),
// Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text. (Default 0.9)
top_p: z.number().positive().optional(),
}).optional(),
template: z.string().optional(), // overrides what is defined in the Modelfile
stream: z.boolean().optional(), // default: true

// Future Improvements?
// n: z.number().int().optional(), // number of completions to generate
// functions: ...
// function_call: ...
});
export type WireOllamaChatCompletionInput = z.infer<typeof wireOllamaChatCompletionInputSchema>;


/**
* Chat Completion or Generation APIs - Streaming Response
*/
export const wireOllamaChunkedOutputSchema = z.object({
model: z.string(),
// created_at: z.string(), // commented because unused
response: z.string(),

// [Chat Completion] (exclusive with 'response')
message: z.object({
role: z.enum(['assistant' /*, 'system', 'user' Disabled on purpose, to validate the response */]),
content: z.string(),
}).optional(), // optional on the last message

// [Generation] (non-chat, exclusive with 'message')
//response: z.string().optional(),

done: z.boolean(),

// only on the last message
// context: z.array(z.number()),
// context: z.array(z.number()), // non-chat endpoint
// total_duration: z.number(),
// load_duration: z.number(),
// eval_duration: z.number(),
// prompt_eval_count: z.number(),
// prompt_eval_duration: z.number(),
// eval_count: z.number(),
});
// eval_duration: z.number(),

});
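
A brief sketch (not from the commit) of how the new chunked-output schema validates one newline-delimited chunk from the Chat endpoint; the import path and the sample chunk are illustrative:

```ts
import { wireOllamaChunkedOutputSchema } from './ollama.wiretypes';

// One line of the 'json-nl' stream, as the Chat endpoint might emit it (made-up sample).
const rawChunk = '{"model":"llama2","message":{"role":"assistant","content":"Hello"},"done":false}';
const chunk = wireOllamaChunkedOutputSchema.parse(JSON.parse(rawChunk));
console.log(chunk.message?.content); // 'Hello'
console.log(chunk.done);             // false
```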
32 changes: 19 additions & 13 deletions src/modules/llms/transports/server/openai/openai.streaming.ts
@@ -6,10 +6,10 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, S

import type { AnthropicWire } from '../anthropic/anthropic.wiretypes';
import type { OpenAIWire } from './openai.wiretypes';
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from '../ollama/ollama.router';
import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from '../anthropic/anthropic.router';
import { ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from '../ollama/ollama.router';
import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai.router';
import { wireOllamaGenerationSchema } from '../ollama/ollama.wiretypes';
import { wireOllamaChunkedOutputSchema } from '../ollama/ollama.wiretypes';


/**
@@ -59,10 +59,10 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise<Res
break;

case 'ollama':
headersUrl = ollamaAccess(access, '/api/generate');
headersUrl = ollamaAccess(access, OLLAMA_PATH_CHAT);
body = ollamaChatCompletionPayload(model, history, true);
eventStreamFormat = 'json-nl';
vendorStreamParser = createOllamaStreamParser();
vendorStreamParser = createOllamaChatCompletionStreamParser();
break;

case 'azure':
@@ -135,30 +135,35 @@ function createAnthropicStreamParser(): AIStreamParser {
};
}

function createOllamaStreamParser(): AIStreamParser {
function createOllamaChatCompletionStreamParser(): AIStreamParser {
let hasBegun = false;

return (data: string) => {

let wireGeneration: any;
// parse the JSON chunk
let wireJsonChunk: any;
try {
wireGeneration = JSON.parse(data);
wireJsonChunk = JSON.parse(data);
} catch (error: any) {
// log the malformed data to the console, and rethrow to transmit as 'error'
console.log(`/api/llms/stream: Ollama parsing issue: ${error?.message || error}`, data);
throw error;
}
const generation = wireOllamaGenerationSchema.parse(wireGeneration);
let text = generation.response;

// validate chunk
const chunk = wireOllamaChunkedOutputSchema.parse(wireJsonChunk);

// process output
let text = chunk.message?.content || /*chunk.response ||*/ '';

// hack: prepend the model name to the first packet
if (!hasBegun) {
if (!hasBegun && chunk.model) {
hasBegun = true;
const firstPacket: ChatStreamFirstPacketSchema = { model: generation.model };
const firstPacket: ChatStreamFirstPacketSchema = { model: chunk.model };
text = JSON.stringify(firstPacket) + text;
}

return { text, close: generation.done };
return { text, close: chunk.done };
};
}

@@ -248,7 +253,8 @@ function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFor
if (close)
controller.terminate();
} catch (error: any) {
// console.log(`/api/llms/stream: parse issue: ${error?.message || error}`);
if (SERVER_DEBUG_WIRE)
console.log(' - E: parse issue:', event.data, error?.message || error);
controller.enqueue(textEncoder.encode(`[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}`));
controller.terminate();
}
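
As a rough sketch of what the relay's 'json-nl' handling amounts to (the real code goes through createEventStreamTransformer and createOllamaChatCompletionStreamParser above), this is one way to consume Ollama's newline-delimited JSON chat stream; the host and model are assumptions:

```ts
// Sketch only: read a streaming /api/chat response as newline-delimited JSON objects.
async function* ollamaChatStream(prompt: string): AsyncGenerator<string> {
  const res = await fetch('http://127.0.0.1:11434/api/chat', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ model: 'llama2', messages: [{ role: 'user', content: prompt }], stream: true }),
  });
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = '';
  while (true) {
    const { value, done } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    let nl: number;
    while ((nl = buffer.indexOf('\n')) >= 0) {
      const line = buffer.slice(0, nl).trim();
      buffer = buffer.slice(nl + 1);
      if (!line) continue;
      const chunk = JSON.parse(line); // each line is one complete JSON object
      if (chunk.message?.content) yield chunk.message.content;
      if (chunk.done) return;
    }
  }
}
```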
