Skip to content

Commit

Permalink
feat: add cortex embeddings command
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-menlo committed Jun 6, 2024
1 parent 6200cce commit 25057f5
Show file tree
Hide file tree
Showing 14 changed files with 262 additions and 48 deletions.
2 changes: 2 additions & 0 deletions cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { FileManagerModule } from './file-manager/file-manager.module';
import { PSCommand } from './infrastructure/commanders/ps.command';
import { KillCommand } from './infrastructure/commanders/kill.command';
import { PresetCommand } from './infrastructure/commanders/presets.command';
import { EmbeddingCommand } from './infrastructure/commanders/embeddings.command';

@Module({
imports: [
Expand Down Expand Up @@ -54,6 +55,7 @@ import { PresetCommand } from './infrastructure/commanders/presets.command';
PSCommand,
KillCommand,
PresetCommand,
EmbeddingCommand,

// Questions
InitRunModeQuestions,
Expand Down
5 changes: 3 additions & 2 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ export class ChatCommand extends CommandRunner {
async run(_input: string[], options: ChatOptions): Promise<void> {
let modelId = _input[0];
// First attempt to get message from input or options
let message = _input[1] ?? options.message;
// Extract input from 1 to end of array
let message = options.message ?? _input.slice(1).join(' ');

// Check for model existing
if (!modelId || !(await this.modelsUsecases.findOne(modelId))) {
// Model ID is not provided
// first input might be message input
message = _input[0] ?? options.message;
message = _input.length ? _input.join(' ') : options.message ?? '';
// If model ID is not provided, prompt user to select from running models
const models = await this.psCliUsecases.getModels();
if (models.length === 1) {
Expand Down
22 changes: 0 additions & 22 deletions cortex-js/src/infrastructure/commanders/constants/huggingface.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { PSCommand } from './ps.command';
import { KillCommand } from './kill.command';
import pkg from '@/../package.json';
import { PresetCommand } from './presets.command';
import { EmbeddingCommand } from './embeddings.command';

interface CortexCommandOptions {
version: boolean;
Expand All @@ -24,6 +25,7 @@ interface CortexCommandOptions {
PSCommand,
KillCommand,
PresetCommand,
EmbeddingCommand,
],
description: 'Cortex CLI',
})
Expand Down
101 changes: 101 additions & 0 deletions cortex-js/src/infrastructure/commanders/embeddings.command.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import {
CommandRunner,
InquirerService,
Option,
SubCommand,
} from 'nest-commander';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { ModelStat, PSCliUsecases } from './usecases/ps.cli.usecases';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { inspect } from 'util';

/**
 * Options accepted by the `embeddings` sub-command. Every field is optional
 * and is populated from the matching command-line flag when present.
 */
interface EmbeddingCommandOptions {
  /** Input text to embed (`-i, --input`). */
  input?: string;
  /** Encoding format for the embeddings (`-e, --encoding_format`). */
  encoding_format?: string;
  /** Number of dimensions of the resulting embeddings (`-d, --dimensions`). */
  dimensions?: number;
}

/**
 * `embeddings` sub-command: requests an embedding vector for the given input
 * text and prints the response to stdout.
 *
 * Positional arguments are `[model_id] [text...]`. The text may also be given
 * with `--input`. When the first positional argument is not a known model, all
 * positional arguments are treated as input text and the model is picked from
 * the currently running models.
 */
@SubCommand({
  name: 'embeddings',
  description: 'Creates an embedding vector representing the input text.',
})
export class EmbeddingCommand extends CommandRunner {
  constructor(
    private readonly chatCliUsecases: ChatCliUsecases,
    private readonly modelsUsecases: ModelsUsecases,
    private readonly psCliUsecases: PSCliUsecases,
    private readonly inquirerService: InquirerService,
  ) {
    super();
  }

  /**
   * Resolves the model and the input text, calls the embeddings use case and
   * prints the result.
   *
   * @param _input Positional arguments: optional model ID followed by the text.
   * @param options Parsed command-line flags.
   */
  async run(_input: string[], options: EmbeddingCommandOptions): Promise<void> {
    let modelId = _input[0];
    // `slice` (not `splice`) so `_input` stays intact for the fallback below.
    let input: string | string[] = options.input ?? _input.slice(1);

    // Check for model existing
    if (!modelId || !(await this.modelsUsecases.findOne(modelId))) {
      // The first positional argument is not a model ID, so every positional
      // argument belongs to the input text. Fall back to `--input` only when
      // no positional text was given.
      input = _input.length > 0 ? _input : (options.input ?? []);

      // Model ID is not provided: pick one from the running models.
      const models = await this.psCliUsecases.getModels();
      if (models.length === 1) {
        modelId = models[0].modelId;
      } else if (models.length > 0) {
        modelId = await this.modelInquiry(models);
      } else {
        console.error('Model ID is required');
        process.exit(1);
      }
    }

    try {
      // Forward the optional flags; an undefined `encoding_format` lets the
      // use case apply its own default.
      const res = await this.chatCliUsecases.embeddings(
        modelId,
        input,
        options.encoding_format,
        options.dimensions,
      );
      console.log(
        inspect(res, { showHidden: false, depth: null, colors: true }),
      );
    } catch (err) {
      console.error(err);
    }
  }

  /**
   * Prompts the user to choose one of the running models.
   *
   * @param models Running models to choose from.
   * @returns The `modelId` of the selected entry.
   */
  modelInquiry = async (models: ModelStat[]) => {
    const { model } = await this.inquirerService.inquirer.prompt({
      type: 'list',
      name: 'model',
      message: 'Select running model to create embeddings with:',
      choices: models.map((e) => ({
        name: e.modelId,
        value: e.modelId,
      })),
    });
    return model;
  };

  @Option({
    flags: '-i, --input <input>',
    description:
      'Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays.',
  })
  parseInput(value: string) {
    return value;
  }

  @Option({
    flags: '-e, --encoding_format <encoding_format>',
    description:
      'Encoding format for the embeddings. Supported formats are float and int.',
  })
  parseEncodingFormat(value: string) {
    return value;
  }

  /**
   * Converts the raw `--dimensions` value to a number so it matches the
   * `dimensions?: number` option type.
   *
   * @throws Error when the value is not a positive integer.
   */
  @Option({
    flags: '-d, --dimensions <dimensions>',
    description:
      'The number of dimensions the resulting output embeddings should have. Only supported in some models.',
  })
  parseDimensionsFormat(value: string): number {
    const dimensions = Number(value);
    if (!Number.isInteger(dimensions) || dimensions <= 0) {
      throw new Error(
        `Invalid dimensions "${value}": expected a positive integer.`,
      );
    }
    return dimensions;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,30 @@ export class ChatCliUsecases {
}
}

/**
 * Creates an embedding vector representing the input text by delegating to
 * `chatUsecases.embeddings` with the same arguments.
 *
 * @param model Embedding model ID.
 * @param input Input text to embed, as a single string or an array of strings.
 * @param encoding_format Encoding format for the embeddings; defaults to `'float'`.
 * @param dimensions The number of dimensions the resulting output embeddings should have. Only supported in some models.
 * @returns Whatever `chatUsecases.embeddings` returns for these arguments.
 */
embeddings(
  model: string,
  input: string | string[],
  encoding_format: string = 'float',
  dimensions?: number,
) {
  const result = this.chatUsecases.embeddings(
    model,
    input,
    encoding_format,
    dimensions,
  );
  return result;
}

private async getOrCreateNewThread(
modelId: string,
threadId?: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@ import { FileManagerService } from '@/file-manager/file-manager.service';
import { rm } from 'fs/promises';
import { exec } from 'child_process';
import { appPath } from '../utils/app-path';
import { CORTEX_RELEASES_URL, CUDA_DOWNLOAD_URL } from '../../constants/cortex';

@Injectable()
export class InitCliUsecases {
private readonly CORTEX_RELEASES_URL =
'https://api.github.com/repos/janhq/cortex/releases';
private readonly CUDA_DOWNLOAD_URL =
'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz';

constructor(
private readonly httpService: HttpService,
private readonly fileManagerService: FileManagerService,
Expand All @@ -30,7 +26,7 @@ export class InitCliUsecases {
): Promise<any> => {
const res = await firstValueFrom(
this.httpService.get(
this.CORTEX_RELEASES_URL + `${version === 'latest' ? '/latest' : ''}`,
CORTEX_RELEASES_URL + `${version === 'latest' ? '/latest' : ''}`,
{
headers: {
'X-GitHub-Api-Version': '2022-11-28',
Expand Down Expand Up @@ -182,7 +178,7 @@ export class InitCliUsecases {
const platform = process.platform === 'win32' ? 'windows' : 'linux';

const dataFolderPath = await this.fileManagerService.getDataFolderPath();
const url = this.CUDA_DOWNLOAD_URL.replace(
const url = CUDA_DOWNLOAD_URL.replace(
'<version>',
options.cudaVersion === '11' ? '11.7' : '12.0',
).replace('<platform>', platform);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
OPEN_CHAT_3_5_JINJA,
ZEPHYR,
ZEPHYR_JINJA,
} from '../constants/prompt-constants';
} from '../../constants/prompt-constants';
import { ModelTokenizer } from '../types/model-tokenizer.interface';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
Expand All @@ -29,6 +29,12 @@ import { join, basename } from 'path';
import { load } from 'js-yaml';
import { existsSync, readFileSync } from 'fs';
import { isLocalModel, normalizeModelId } from '../utils/normalize-model-id';
import {
HUGGING_FACE_DOWNLOAD_FILE_MAIN_URL,
HUGGING_FACE_REPO_MODEL_API_URL,
HUGGING_FACE_REPO_URL,
HUGGING_FACE_TREE_REF_URL,
} from '../../constants/huggingface';

@Injectable()
export class ModelsCliUsecases {
Expand Down Expand Up @@ -68,9 +74,7 @@ export class ModelsCliUsecases {
* @param modelId
*/
async stopModel(modelId: string): Promise<void> {
return this.getModelOrStop(modelId)
.then(() => this.modelsUsecases.stopModel(modelId))
.then();
return this.modelsUsecases.stopModel(modelId).then();
}

/**
Expand Down Expand Up @@ -310,7 +314,7 @@ export class ModelsCliUsecases {
? response.map((e) => {
return {
rfilename: e.path,
downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`,
downloadUrl: HUGGING_FACE_TREE_REF_URL(repo, tree, e.path),
fileSize: e.size ?? 0,
};
})
Expand Down Expand Up @@ -367,7 +371,11 @@ export class ModelsCliUsecases {
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0);

for (let i = 0; i < data.siblings.length; i++) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${data.siblings[i].rfilename}`;
const downloadUrl = HUGGING_FACE_DOWNLOAD_FILE_MAIN_URL(
paths[2],
paths[3],
data.siblings[i].rfilename,
);
data.siblings[i].downloadUrl = downloadUrl;
}

Expand All @@ -379,12 +387,12 @@ export class ModelsCliUsecases {
});
});

data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}`;
data.modelUrl = HUGGING_FACE_REPO_URL(paths[2], paths[3]);
return data;
}

private getRepoModelsUrl(repoId: string, tree?: string): string {
return `https://huggingface.co/api/models/${repoId}${tree ? `/tree/${tree}` : ''}`;
return `${HUGGING_FACE_REPO_MODEL_API_URL(repoId)}${tree ? `/tree/${tree}` : ''}`;
}

private async parsePreset(preset?: string): Promise<object> {
Expand Down
24 changes: 24 additions & 0 deletions cortex-js/src/infrastructure/constants/cortex.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

// CORTEX CPP
/**
 * Builds the URL of the Cortex CPP embedding endpoint.
 *
 * @param host Server host; defaults to `defaultCortexCppHost`.
 * @param port Server port; defaults to `defaultCortexCppPort`.
 * @returns `http://<host>:<port>/inferences/server/embedding`
 */
export const CORTEX_CPP_EMBEDDINGS_URL = (
  host: string = defaultCortexCppHost,
  port: number = defaultCortexCppPort,
): string => {
  const origin = `http://${host}:${port}`;
  return `${origin}/inferences/server/embedding`;
};

/**
 * Builds the URL of the Cortex CPP process-manager destroy endpoint.
 *
 * @param host Server host; defaults to `defaultCortexCppHost`.
 * @param port Server port; defaults to `defaultCortexCppPort`.
 * @returns `http://<host>:<port>/processmanager/destroy`
 */
export const CORTEX_CPP_PROCESS_DESTROY_URL = (
  host: string = defaultCortexCppHost,
  port: number = defaultCortexCppPort,
): string => {
  const origin = `http://${host}:${port}`;
  return `${origin}/processmanager/destroy`;
};

/**
 * Builds the URL of the Cortex CPP `healthz` endpoint.
 *
 * @param host Server host; defaults to `defaultCortexCppHost`.
 * @param port Server port; defaults to `defaultCortexCppPort`.
 * @returns `http://<host>:<port>/healthz`
 */
export const CORTEX_CPP_HEALTH_Z_URL = (
  host: string = defaultCortexCppHost,
  port: number = defaultCortexCppPort,
): string => {
  const origin = `http://${host}:${port}`;
  return `${origin}/healthz`;
};

// INITIALIZATION
// GitHub API URL for the janhq/cortex releases. The init use case appends
// `/latest` to it when the requested version is 'latest'.
export const CORTEX_RELEASES_URL =
'https://api.github.com/repos/janhq/cortex/releases';

// URL template for the CUDA dependencies archive. The `<version>` and
// `<platform>` placeholders are substituted by the caller via `.replace(...)`
// before the URL is used.
export const CUDA_DOWNLOAD_URL =
'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz';
40 changes: 40 additions & 0 deletions cortex-js/src/infrastructure/constants/huggingface.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
 * Builds the download URL of a file at a given tree ref of a repository under
 * the `janhq` account on Hugging Face.
 *
 * @param repo Repository name (without the `janhq/` prefix).
 * @param tree Tree ref placed after `/resolve/`.
 * @param path File path inside the repository.
 * @returns `https://huggingface.co/janhq/<repo>/resolve/<tree>/<path>`
 */
export const HUGGING_FACE_TREE_REF_URL = (
  repo: string,
  tree: string,
  path: string,
): string => {
  const repoRoot = `https://huggingface.co/janhq/${repo}`;
  return `${repoRoot}/resolve/${tree}/${path}`;
};

/**
 * Builds the download URL of a file on the `main` ref of a Hugging Face
 * repository.
 *
 * @param author Repository owner.
 * @param repo Repository name.
 * @param fileName File name (path) inside the repository.
 * @returns `https://huggingface.co/<author>/<repo>/resolve/main/<fileName>`
 */
export const HUGGING_FACE_DOWNLOAD_FILE_MAIN_URL = (
  author: string,
  repo: string,
  fileName: string,
): string => {
  const repoRoot = `https://huggingface.co/${author}/${repo}`;
  return `${repoRoot}/resolve/main/${fileName}`;
};

/**
 * Builds the web URL of a Hugging Face repository.
 *
 * @param author Repository owner.
 * @param repo Repository name.
 * @returns `https://huggingface.co/<author>/<repo>`
 */
export const HUGGING_FACE_REPO_URL = (author: string, repo: string): string => {
  const ownerRoot = `https://huggingface.co/${author}`;
  return `${ownerRoot}/${repo}`;
};

/**
 * Builds the Hugging Face models API URL for a repository.
 *
 * @param repo Repository identifier appended after `/api/models/`.
 * @returns `https://huggingface.co/api/models/<repo>`
 */
export const HUGGING_FACE_REPO_MODEL_API_URL = (repo: string): string => {
  const apiRoot = 'https://huggingface.co/api/models';
  return `${apiRoot}/${repo}`;
};

/**
 * Quantization labels known to the CLI (for example `Q4_K_M`).
 *
 * The element order is kept exactly as originally defined.
 * NOTE(review): confirm whether any caller relies on this order.
 */
export const AllQuantizations: string[] = [
  // Q3 K-variants
  'Q3_K_S', 'Q3_K_M', 'Q3_K_L',
  // Q4 / Q5 K-variants
  'Q4_K_S', 'Q4_K_M', 'Q5_K_S', 'Q5_K_M',
  // Q4 / Q5 numbered variants
  'Q4_0', 'Q4_1', 'Q5_0', 'Q5_1',
  // IQ2 and Q2
  'IQ2_XXS', 'IQ2_XS', 'Q2_K', 'Q2_K_S',
  // Q6 / Q8
  'Q6_K', 'Q8_0',
  // F16 / F32 / COPY
  'F16', 'F32', 'COPY',
];
3 changes: 2 additions & 1 deletion cortex-js/src/usecases/chat/chat.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import { ChatUsecases } from './chat.usecases';
import { DatabaseModule } from '@/infrastructure/database/database.module';
import { ExtensionModule } from '@/infrastructure/repositories/extensions/extension.module';
import { ModelRepositoryModule } from '@/infrastructure/repositories/model/model.module';
import { HttpModule } from '@nestjs/axios';

@Module({
imports: [DatabaseModule, ExtensionModule, ModelRepositoryModule],
imports: [DatabaseModule, ExtensionModule, ModelRepositoryModule, HttpModule],
controllers: [ChatController],
providers: [ChatUsecases],
exports: [ChatUsecases],
Expand Down
Loading

0 comments on commit 25057f5

Please sign in to comment.