Skip to content

Commit

Permalink
feat: support openrouter, cohere engine (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
marknguyen1302 authored Jul 31, 2024
1 parent 779907f commit 0511475
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 2 deletions.
1 change: 1 addition & 0 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export abstract class OAIEngineExtension extends EngineExtension {
}),
);


if (!response) {
throw new Error('No response');
}
Expand Down
111 changes: 111 additions & 0 deletions cortex-js/src/extensions/cohere.engine.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import stream from 'stream';
import { HttpService } from '@nestjs/axios';
import { OAIEngineExtension } from '../domain/abstracts/oai.abstract';
import { ConfigsUsecases } from '@/usecases/configs/configs.usecase';
import { EventEmitter2 } from '@nestjs/event-emitter';
import _ from 'lodash';
import { EngineStatus } from '@/domain/abstracts/engine.abstract';
import { ChatCompletionMessage } from '@/infrastructure/dtos/chat/chat-completion-message.dto';

enum RoleType {
user = 'USER',
chatbot = 'CHATBOT',
system = 'SYSTEM',
}

type CoherePayloadType = {
chat_history?: Array<{ role: RoleType; message: string }>
message?: string
preamble?: string
}

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class CoHereEngineExtension extends OAIEngineExtension {
apiUrl = 'https://api.cohere.ai/v1/chat';
name = 'cohere';
productName = 'Cohere Inference Engine';
description = 'This extension enables Cohere chat completion API calls';
version = '0.0.1';
apiKey?: string;

constructor(
protected readonly httpService: HttpService,
protected readonly configsUsecases: ConfigsUsecases,
protected readonly eventEmmitter: EventEmitter2,
) {
super(httpService);

eventEmmitter.on('config.updated', async (data) => {
if (data.engine === this.name) {
this.apiKey = data.value;
this.status =
(this.apiKey?.length ?? 0) > 0
? EngineStatus.READY
: EngineStatus.MISSING_CONFIGURATION;
}
});
}

async onLoad() {
const configs = (await this.configsUsecases.getGroupConfigs(
this.name,
)) as unknown as { apiKey: string };
this.apiKey = configs?.apiKey;
this.status =
(this.apiKey?.length ?? 0) > 0
? EngineStatus.READY
: EngineStatus.MISSING_CONFIGURATION;
}

transformPayload = (payload: any): CoherePayloadType => {
console.log('payload', payload)
if (payload.messages.length === 0) {
return {}
}

const { messages, ...params } = payload;
const convertedData: CoherePayloadType = {
...params,
chat_history: [],
message: '',
};
(messages as ChatCompletionMessage[]).forEach((item: ChatCompletionMessage, index: number) => {
// Assign the message of the last item to the `message` property
if (index === messages.length - 1) {
convertedData.message = item.content as string
return
}
if (item.role === 'user') {
convertedData.chat_history!!.push({
role: 'USER' as RoleType,
message: item.content as string,
})
} else if (item.role === 'assistant') {
convertedData.chat_history!!.push({
role: 'CHATBOT' as RoleType,
message: item.content as string,
})
} else if (item.role === 'system') {
convertedData.preamble = item.content as string
}
})
return convertedData
}

transformResponse = (data: any) => {
const text = typeof data === 'object' ? data.text : JSON.parse(data).text ?? ''
return JSON.stringify({
choices: [
{
delta: {
content: text,
},
},
],
});
}
}
4 changes: 4 additions & 0 deletions cortex-js/src/extensions/extensions.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { ConfigsUsecases } from '@/usecases/configs/configs.usecase';
import { ConfigsModule } from '@/usecases/configs/configs.module';
import { EventEmitter2 } from '@nestjs/event-emitter';
import AnthropicEngineExtension from './anthropic.engine';
import OpenRouterEngineExtension from './openrouter.engine';
import CoHereEngineExtension from './cohere.engine';

const provider = {
provide: 'EXTENSIONS_PROVIDER',
Expand All @@ -20,6 +22,8 @@ const provider = {
new GroqEngineExtension(httpService, configUsecases, eventEmitter),
new MistralEngineExtension(httpService, configUsecases, eventEmitter),
new AnthropicEngineExtension(httpService, configUsecases, eventEmitter),
new OpenRouterEngineExtension(httpService, configUsecases, eventEmitter),
new CoHereEngineExtension(httpService, configUsecases, eventEmitter),
],
};

Expand Down
58 changes: 58 additions & 0 deletions cortex-js/src/extensions/openrouter.engine.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import stream from 'stream';
import { HttpService } from '@nestjs/axios';
import { OAIEngineExtension } from '../domain/abstracts/oai.abstract';
import { ConfigsUsecases } from '@/usecases/configs/configs.usecase';
import { EventEmitter2 } from '@nestjs/event-emitter';
import _ from 'lodash';
import { EngineStatus } from '@/domain/abstracts/engine.abstract';

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class OpenRouterEngineExtension extends OAIEngineExtension {
apiUrl = 'https://openrouter.ai/api/v1/chat/completions';
name = 'openrouter';
productName = 'OpenRouter Inference Engine';
description = 'This extension enables OpenRouter chat completion API calls';
version = '0.0.1';
apiKey?: string;

constructor(
protected readonly httpService: HttpService,
protected readonly configsUsecases: ConfigsUsecases,
protected readonly eventEmmitter: EventEmitter2,
) {
super(httpService);

eventEmmitter.on('config.updated', async (data) => {
if (data.engine === this.name) {
this.apiKey = data.value;
this.status =
(this.apiKey?.length ?? 0) > 0
? EngineStatus.READY
: EngineStatus.MISSING_CONFIGURATION;
}
});
}

async onLoad() {
const configs = (await this.configsUsecases.getGroupConfigs(
this.name,
)) as unknown as { apiKey: string };
this.apiKey = configs?.apiKey;
this.status =
(this.apiKey?.length ?? 0) > 0
? EngineStatus.READY
: EngineStatus.MISSING_CONFIGURATION;
}

transformPayload = (data: any): any => {
return {
...data,
model:"openrouter/auto",
}
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { downloadProgress } from '@/utils/download-progress';
import { CortexClient } from '../services/cortex.client';
import { DownloadType } from '@/domain/models/download.interface';
import ora from 'ora';
import { isLocalFile } from '@/utils/urls';
import { isRemoteEngine } from '@/utils/normalize-model-id';

@SubCommand({
name: 'pull',
Expand Down Expand Up @@ -70,6 +70,7 @@ export class ModelPullCommand extends BaseCommand {

// Pull engine if not exist
if (
!isRemoteEngine(engine) &&
!existsSync(join(await this.fileService.getCortexCppEnginePath(), engine))
) {
console.log('\n');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export enum Engines {
mistral = 'mistral',
openai = 'openai',
anthropic = 'anthropic',
openrouter = 'openrouter',
cohere = 'cohere',
}

export const EngineNamesMap: {
Expand All @@ -23,4 +25,6 @@ export const RemoteEngines: Engines[] = [
Engines.mistral,
Engines.openai,
Engines.anthropic,
Engines.openrouter,
Engines.cohere,
];
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ export class EnginesController {
description: 'The unique identifier of the engine.',
})
@Patch(':name(*)')
update(@Param('name') name: string, @Body() configs: ConfigUpdateDto) {
update(@Param('name') name: string, @Body() configs?: any | undefined) {
console.log('configs', configs)
return this.enginesUsecases.updateConfigs(
configs.config,
configs.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ export class DownloadManagerService {

writer.on('finish', () => {
try {
if (timeoutId) clearTimeout(timeoutId);
// delete the abort controller
delete this.abortControllers[downloadId][destination];
const currentDownloadState = this.allDownloadStates.find(
Expand All @@ -210,6 +211,7 @@ export class DownloadManagerService {
});
writer.on('error', (error) => {
try {
if (timeoutId) clearTimeout(timeoutId);
this.handleError(error, downloadId, destination);
} finally {
bar.stop();
Expand Down

0 comments on commit 0511475

Please sign in to comment.