Mistral Platform: full support
Closes #273.
enricoros committed Dec 13, 2023
1 parent a265112 commit c0c724a
Showing 9 changed files with 252 additions and 7 deletions.
10 changes: 10 additions & 0 deletions src/common/components/icons/MistralIcon.tsx
@@ -0,0 +1,10 @@
import * as React from 'react';

import { SvgIcon } from '@mui/joy';
import { SxProps } from '@mui/joy/styles/types';

export function MistralIcon(props: { sx?: SxProps }) {
return <SvgIcon viewBox='0 0 24 24' width='24' height='24' strokeWidth={0} stroke='none' fill='currentColor' strokeLinecap='butt' strokeLinejoin='miter' {...props}>
<path d='m 2,2 v 4 4 V 14 v 4 4 h 4 v -4 -4 h 4 v 4 h 4 v -4 h 4 v 4 4 h 4 v -4 -4 -4 -4 V 2 h -4 v 4 h -4 v 4 h -4 v -4 H 6 V 2 Z' />
</SvgIcon>;
}
33 changes: 33 additions & 0 deletions src/modules/llms/transports/server/openai/mistral.wiretypes.ts
@@ -0,0 +1,33 @@
import { z } from 'zod';


// [Mistral] Models List API - Response

export const wireMistralModelsListOutputSchema = z.object({
id: z.string(),
object: z.literal('model'),
created: z.number(),
owned_by: z.string(),
root: z.null().optional(),
parent: z.null().optional(),
// permission: z.array(wireMistralModelsListPermissionsSchema)
});

// export type WireMistralModelsListOutput = z.infer<typeof wireMistralModelsListOutputSchema>;

/*
const wireMistralModelsListPermissionsSchema = z.object({
id: z.string(),
object: z.literal('model_permission'),
created: z.number(),
allow_create_engine: z.boolean(),
allow_sampling: z.boolean(),
allow_logprobs: z.boolean(),
allow_search_indices: z.boolean(),
allow_view: z.boolean(),
allow_fine_tuning: z.boolean(),
organization: z.string(),
group: z.null().optional(),
is_blocking: z.boolean()
});
*/
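
A minimal sketch of validating one models-list entry against this schema; the sample payload is illustrative, not captured from the live Mistral API:

import { wireMistralModelsListOutputSchema } from './mistral.wiretypes';

// hypothetical entry, shaped like one item of the `data` array of a models-list response
const sampleEntry = {
  id: 'mistral-small',
  object: 'model',
  created: 1702425600,
  owned_by: 'mistralai',
};

// .parse() throws a ZodError on shape mismatch; .safeParse() returns a result object instead
const model = wireMistralModelsListOutputSchema.parse(sampleEntry);
console.log(model.id); // 'mistral-small'
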
66 changes: 63 additions & 3 deletions src/modules/llms/transports/server/openai/models.data.ts
@@ -1,7 +1,10 @@
-import type { ModelDescriptionSchema } from '../server.schemas';
-import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../../store-llms';
import { SERVER_DEBUG_WIRE } from '~/server/wire';

+import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../../store-llms';
+
+import type { ModelDescriptionSchema } from '../server.schemas';
+import { wireMistralModelsListOutputSchema } from './mistral.wiretypes';


// [Azure] / [OpenAI]
const _knownOpenAIChatModels: ManualMappings = [
@@ -204,6 +207,63 @@ export function localAIModelToModelDescription(modelId: string): ModelDescriptio
}


// [Mistral]

const _knownMistralChatModels: ManualMappings = [
{
idPrefix: 'mistral-medium',
label: 'Mistral Medium',
description: 'Mistral internal prototype model.',
contextWindow: 32768,
interfaces: [LLM_IF_OAI_Chat],
},
{
idPrefix: 'mistral-small',
label: 'Mistral Small',
description: 'Higher reasoning capabilities and broader coverage (English, French, German, Italian, Spanish, and code)',
contextWindow: 32768,
interfaces: [LLM_IF_OAI_Chat],
},
{
idPrefix: 'mistral-tiny',
label: 'Mistral Tiny',
description: 'Used for large batch processing tasks where cost is a significant factor but reasoning capabilities are not crucial',
contextWindow: 32768,
interfaces: [LLM_IF_OAI_Chat],
},
{
idPrefix: 'mistral-embed',
label: 'Mistral Embed',
description: 'Mistral embedding model (1024-dimension output)',
// output: 1024 dimensions
maxCompletionTokens: 1024, // HACK - it's 1024 dimensions, but those are not 'completion tokens'
contextWindow: 32768, // actually unknown, assumed from the other models
interfaces: [],
hidden: true,
},
];

export function mistralModelToModelDescription(_model: unknown): ModelDescriptionSchema {
const model = wireMistralModelsListOutputSchema.parse(_model);
return fromManualMapping(_knownMistralChatModels, model.id, model.created, undefined, {
idPrefix: model.id,
label: model.id.replaceAll(/[_-]/g, ' '),
description: 'New Mistral Model',
contextWindow: 32768,
interfaces: [LLM_IF_OAI_Chat], // assume..
hidden: true,
});
}

export function mistralModelsSort(a: ModelDescriptionSchema, b: ModelDescriptionSchema): number {
if (a.hidden && !b.hidden)
return 1;
if (!a.hidden && b.hidden)
return -1;
return a.id.localeCompare(b.id);
}


// [Oobabooga]
const _knownOobaboogaChatModels: ManualMappings = [];

@@ -346,7 +406,7 @@ export function openRouterModelToModelDescription(modelId: string, created: numb
const orModel = orModelMap[modelId] ?? null;
let label = orModel?.name || modelId.replace('/', ' · ');
if (orModel?.cp === 0 && orModel?.cc === 0)
-label += ' · 🎁' // Free? Discounted?
+label += ' · 🎁'; // Free? Discounted?

if (SERVER_DEBUG_WIRE && !orModel)
console.log(' - openRouterModelToModelDescription: non-mapped model id:', modelId);
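
To see the new mapping path end-to-end, here is a sketch of raw wire entries flowing through mistralModelToModelDescription and mistralModelsSort; the entries are made up for illustration:

import { mistralModelsSort, mistralModelToModelDescription } from './models.data';

// illustrative wire entries: the hidden embedding model and a known chat model
const rawModels = [
  { id: 'mistral-embed', object: 'model', created: 1702425600, owned_by: 'mistralai' },
  { id: 'mistral-tiny', object: 'model', created: 1702425600, owned_by: 'mistralai' },
];

const descriptions = rawModels
  .map(mistralModelToModelDescription) // zod-validate, then match against _knownMistralChatModels
  .sort(mistralModelsSort); // visible models first, then alphabetical by id

console.log(descriptions.map(d => d.id)); // ['mistral-tiny', 'mistral-embed']
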
27 changes: 24 additions & 3 deletions src/modules/llms/transports/server/openai/openai.router.ts
@@ -9,12 +9,12 @@ import { Brand } from '~/common/app.config';

import type { OpenAIWire } from './openai.wiretypes';
import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas';
-import { localAIModelToModelDescription, oobaboogaModelToModelDescription, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription } from './models.data';
+import { localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription } from './models.data';


// Input Schemas

-const openAIDialects = z.enum(['azure', 'localai', 'oobabooga', 'openai', 'openrouter']);
+const openAIDialects = z.enum(['azure', 'localai', 'mistral', 'oobabooga', 'openai', 'openrouter']);

export const openAIAccessSchema = z.object({
dialect: openAIDialects,
@@ -186,12 +186,18 @@ export const llmOpenAIRouter = createTRPCRouter({
.map((model): ModelDescriptionSchema => openAIModelToModelDescription(model.id, model.created));
break;

case 'mistral':
models = openAIModels
.map(mistralModelToModelDescription)
.sort(mistralModelsSort);
break;

case 'openrouter':
models = openAIModels
.sort(openRouterModelFamilySortFn)
.map(model => openRouterModelToModelDescription(model.id, model.created, (model as any)?.['context_length']));
break;

}

return { models };
@@ -267,9 +273,10 @@ async function openaiPOST<TOut extends object, TPostBody extends object>(access:
}


-const DEFAULT_HELICONE_OPENAI_HOST = 'oai.hconeai.com';
+const DEFAULT_MISTRAL_HOST = 'https://api.mistral.ai';
const DEFAULT_OPENAI_HOST = 'api.openai.com';
const DEFAULT_OPENROUTER_HOST = 'https://openrouter.ai/api';
+const DEFAULT_HELICONE_OPENAI_HOST = 'oai.hconeai.com';

export function fixupHost(host: string, apiPath: string): string {
if (!host.startsWith('http'))
@@ -361,6 +368,20 @@ export function openAIAccess(access: OpenAIAccessSchema, modelRefId: string | nu
};


case 'mistral':
// https://docs.mistral.ai/platform/client
const mistralKey = access.oaiKey || env.MISTRAL_API_KEY || '';
const mistralHost = fixupHost(access.oaiHost || DEFAULT_MISTRAL_HOST, apiPath);
return {
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': `Bearer ${mistralKey}`,
},
url: mistralHost + apiPath,
};


case 'openrouter':
const orKey = access.oaiKey || env.OPENROUTER_API_KEY || '';
const orHost = fixupHost(access.oaiHost || DEFAULT_OPENROUTER_HOST, apiPath);
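
As a sketch of what the new 'mistral' branch of openAIAccess assembles for a models-list call (the key below is a placeholder, and the empty oaiHost falls back to DEFAULT_MISTRAL_HOST):

import type { OpenAIAccessSchema } from './openai.router';
import { openAIAccess } from './openai.router';

const access: OpenAIAccessSchema = {
  dialect: 'mistral',
  oaiKey: 'MISTRAL_KEY_PLACEHOLDER', // placeholder, not a real key
  oaiOrg: '',
  oaiHost: '', // empty: resolves to DEFAULT_MISTRAL_HOST
  heliKey: '',
  moderationCheck: false,
};

const { headers, url } = openAIAccess(access, null, '/v1/models');
// url === 'https://api.mistral.ai/v1/models'
// headers carry Content-Type, Accept, and 'Authorization: Bearer MISTRAL_KEY_PLACEHOLDER'
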
@@ -67,6 +67,7 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise<Res

case 'azure':
case 'localai':
case 'mistral':
case 'oobabooga':
case 'openai':
case 'openrouter':
2 changes: 1 addition & 1 deletion src/modules/llms/vendors/IModelVendor.ts
@@ -4,7 +4,7 @@ import type { DLLM, DModelSourceId } from '../store-llms';
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../transports/chatGenerate';


-export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'ollama' | 'oobabooga' | 'openai' | 'openrouter';
+export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'mistral' | 'ollama' | 'oobabooga' | 'openai' | 'openrouter';

export type ModelVendorRegistryType = Record<ModelVendorId, IModelVendor>;

61 changes: 61 additions & 0 deletions src/modules/llms/vendors/mistral/MistralSourceSetup.tsx
@@ -0,0 +1,61 @@
import * as React from 'react';

import { FormInputKey } from '~/common/components/forms/FormInputKey';
import { InlineError } from '~/common/components/InlineError';
import { Link } from '~/common/components/Link';
import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton';
import { apiQuery } from '~/common/util/trpc.client';

import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms';
import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup';

import { ModelVendorMistral } from './mistral.vendor';


const MISTRAL_REG_LINK = 'https://console.mistral.ai/';


export function MistralSourceSetup(props: { sourceId: DModelSourceId }) {

// external state
const { source, sourceSetupValid, sourceHasLLMs, access, updateSetup } =
useSourceSetup(props.sourceId, ModelVendorMistral);

// derived state
const { oaiKey: mistralKey } = access;

const needsUserKey = !ModelVendorMistral.hasBackendCap?.();
const shallFetchSucceed = !needsUserKey || (!!mistralKey && sourceSetupValid);
const showKeyError = !!mistralKey && !sourceSetupValid;

// fetch models
const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, {
enabled: false,
onSuccess: models => source && useModelsStore.getState().setLLMs(
models.models.map(model => modelDescriptionToDLLM(model, source)),
props.sourceId,
),
staleTime: Infinity,
});

return <>

<FormInputKey
id='mistral-key' label='Mistral Key'
rightLabel={<>{needsUserKey
? !mistralKey && <Link level='body-sm' href={MISTRAL_REG_LINK} target='_blank'>request Key</Link>
: '✔️ already set in server'}
</>}
value={mistralKey} onChange={value => updateSetup({ oaiKey: value })}
required={needsUserKey} isError={showKeyError}
placeholder='...'
/>

<SetupFormRefetchButton
refetch={refetch} disabled={/*!shallFetchSucceed ||*/ isFetching} error={isError}
/>

{isError && <InlineError error={error} />}

</>;
}
57 changes: 57 additions & 0 deletions src/modules/llms/vendors/mistral/mistral.vendor.ts
@@ -0,0 +1,57 @@
import { backendCaps } from '~/modules/backend/state-backend';

import { MistralIcon } from '~/common/components/icons/MistralIcon';

import type { IModelVendor } from '../IModelVendor';
import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router';
import type { VChatMessageIn, VChatMessageOut } from '../../transports/chatGenerate';

import { LLMOptionsOpenAI, openAICallChatGenerate, SourceSetupOpenAI } from '../openai/openai.vendor';
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';

import { MistralSourceSetup } from './MistralSourceSetup';


// special symbols

export type SourceSetupMistral = Pick<SourceSetupOpenAI, 'oaiKey' | 'oaiHost'>;


/** Implementation Notes for the Mistral vendor
*/
export const ModelVendorMistral: IModelVendor<SourceSetupMistral, OpenAIAccessSchema, LLMOptionsOpenAI> = {
id: 'mistral',
name: 'Mistral',
rank: 15,
location: 'cloud',
instanceLimit: 1,
hasBackendCap: () => backendCaps().hasLlmMistral,

// components
Icon: MistralIcon,
SourceSetupComponent: MistralSourceSetup,
LLMOptionsComponent: OpenAILLMOptions,

// functions
initializeSetup: () => ({
oaiHost: 'https://api.mistral.ai/',
oaiKey: '',
}),
validateSetup: (setup) => {
return setup.oaiKey?.length >= 32;
},
getTransportAccess: (partialSetup): OpenAIAccessSchema => ({
dialect: 'mistral',
oaiKey: partialSetup?.oaiKey || '',
oaiOrg: '',
oaiHost: partialSetup?.oaiHost || '',
heliKey: '',
moderationCheck: false,
}),
callChatGenerate(llm, messages: VChatMessageIn[], maxTokens?: number): Promise<VChatMessageOut> {
return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens);
},
callChatGenerateWF() {
throw new Error('Mistral does not support "Functions" yet');
},
};
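
Finally, a hedged sketch of driving one chat turn through the vendor object; the llm argument is trimmed to just the fields callChatGenerate reads, and the option names inside it are assumptions about LLMOptionsOpenAI rather than confirmed by this diff:

import { ModelVendorMistral } from './mistral.vendor';

// illustrative, minimal DLLM-like object; real DLLMs carry many more fields
const llm: any = {
  options: { llmRef: 'mistral-small', llmTemperature: 0.7, llmResponseTokens: 512 }, // assumed names
  _source: { setup: { oaiKey: 'MISTRAL_KEY_PLACEHOLDER', oaiHost: '' } },
};

ModelVendorMistral.callChatGenerate(llm, [
  { role: 'user', content: 'In one sentence, what is a mixture-of-experts model?' },
], 256)
  .then(out => console.log(out.content))
  .catch(console.error);
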
2 changes: 2 additions & 0 deletions src/modules/llms/vendors/vendors.registry.ts
@@ -1,6 +1,7 @@
import { ModelVendorAnthropic } from './anthropic/anthropic.vendor';
import { ModelVendorAzure } from './azure/azure.vendor';
import { ModelVendorLocalAI } from './localai/localai.vendor';
import { ModelVendorMistral } from './mistral/mistral.vendor';
import { ModelVendorOllama } from './ollama/ollama.vendor';
import { ModelVendorOoobabooga } from './oobabooga/oobabooga.vendor';
import { ModelVendorOpenAI } from './openai/openai.vendor';
@@ -14,6 +15,7 @@ const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = {
anthropic: ModelVendorAnthropic,
azure: ModelVendorAzure,
localai: ModelVendorLocalAI,
mistral: ModelVendorMistral,
ollama: ModelVendorOllama,
oobabooga: ModelVendorOoobabooga,
openai: ModelVendorOpenAI,

1 comment on commit c0c724a

@vercel bot commented on c0c724a, Dec 13, 2023:

Successfully deployed to the following URLs:

big-agi – ./

big-agi-enricoros.vercel.app
big-agi-git-main-stable-enricoros.vercel.app
get.big-agi.com