Skip to content

Commit

Permalink
feat: add cortex embeddings command
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-menlo committed Jun 6, 2024
1 parent 6200cce commit e1ea4f4
Show file tree
Hide file tree
Showing 15 changed files with 269 additions and 56 deletions.
2 changes: 2 additions & 0 deletions cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { FileManagerModule } from './file-manager/file-manager.module';
import { PSCommand } from './infrastructure/commanders/ps.command';
import { KillCommand } from './infrastructure/commanders/kill.command';
import { PresetCommand } from './infrastructure/commanders/presets.command';
import { EmbeddingCommand } from './infrastructure/commanders/embeddings.command';

@Module({
imports: [
Expand Down Expand Up @@ -54,6 +55,7 @@ import { PresetCommand } from './infrastructure/commanders/presets.command';
PSCommand,
KillCommand,
PresetCommand,
EmbeddingCommand,

// Questions
InitRunModeQuestions,
Expand Down
5 changes: 3 additions & 2 deletions cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ export class ChatCommand extends CommandRunner {
async run(_input: string[], options: ChatOptions): Promise<void> {
let modelId = _input[0];
// First attempt to get message from input or options
let message = _input[1] ?? options.message;
// Extract input from 1 to end of array
let message = options.message ?? _input.slice(1).join(' ');

// Check for model existing
if (!modelId || !(await this.modelsUsecases.findOne(modelId))) {
// Model ID is not provided
// first input might be message input
message = _input[0] ?? options.message;
message = _input.length ? _input.join(' ') : options.message ?? '';
// If model ID is not provided, prompt user to select from running models
const models = await this.psCliUsecases.getModels();
if (models.length === 1) {
Expand Down
22 changes: 0 additions & 22 deletions cortex-js/src/infrastructure/commanders/constants/huggingface.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { PSCommand } from './ps.command';
import { KillCommand } from './kill.command';
import pkg from '@/../package.json';
import { PresetCommand } from './presets.command';
import { EmbeddingCommand } from './embeddings.command';

interface CortexCommandOptions {
version: boolean;
Expand All @@ -24,6 +25,7 @@ interface CortexCommandOptions {
PSCommand,
KillCommand,
PresetCommand,
EmbeddingCommand,
],
description: 'Cortex CLI',
})
Expand Down
101 changes: 101 additions & 0 deletions cortex-js/src/infrastructure/commanders/embeddings.command.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import {
CommandRunner,
InquirerService,
Option,
SubCommand,
} from 'nest-commander';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { ModelStat, PSCliUsecases } from './usecases/ps.cli.usecases';
import { ChatCliUsecases } from './usecases/chat.cli.usecases';
import { inspect } from 'util';

interface EmbeddingCommandOptions {
  encoding_format?: string;
  input?: string;
  dimensions?: number;
}

@SubCommand({
  name: 'embeddings',
  description: 'Creates an embedding vector representing the input text.',
})
export class EmbeddingCommand extends CommandRunner {
  constructor(
    private readonly chatCliUsecases: ChatCliUsecases,
    private readonly modelsUsecases: ModelsUsecases,
    private readonly psCliUsecases: PSCliUsecases,
    private readonly inquirerService: InquirerService,
  ) {
    super();
  }

  /**
   * Entry point for `cortex embeddings [model] [...input]`.
   * The first positional argument may be a model ID; the remaining
   * arguments (or the --input flag) are the text to embed. When the
   * first argument does not resolve to a model, all positional
   * arguments are treated as input and a running model is selected.
   */
  async run(_input: string[], options: EmbeddingCommandOptions): Promise<void> {
    let modelId = _input[0];
    // Prefer the --input flag; otherwise everything after the model ID is
    // the input. slice() (not splice()) leaves _input intact for the
    // fallback below.
    let input: string | string[] = options.input ?? _input.slice(1);

    // Check whether the first positional argument names an existing model
    if (!modelId || !(await this.modelsUsecases.findOne(modelId))) {
      // No model ID was provided, so every positional argument is input text
      input = options.input ?? _input;
      // If model ID is not provided, prompt user to select from running models
      const models = await this.psCliUsecases.getModels();
      if (models.length === 1) {
        modelId = models[0].modelId;
      } else if (models.length > 0) {
        modelId = await this.modelInquiry(models);
      } else {
        console.error('Model ID is required');
        process.exit(1);
      }
    }

    // Forward the parsed CLI flags so --encoding_format / --dimensions
    // actually take effect (encoding_format defaults to 'float' downstream).
    return this.chatCliUsecases
      .embeddings(modelId, input, options.encoding_format, options.dimensions)
      .then((res) =>
        inspect(res, { showHidden: false, depth: null, colors: true }),
      )
      .then(console.log)
      .catch(console.error);
  }

  // Prompt the user to pick one of the currently running models.
  modelInquiry = async (models: ModelStat[]) => {
    const { model } = await this.inquirerService.inquirer.prompt({
      type: 'list',
      name: 'model',
      message: 'Select running model to generate embeddings with:',
      choices: models.map((e) => ({
        name: e.modelId,
        value: e.modelId,
      })),
    });
    return model;
  };

  @Option({
    flags: '-i, --input <input>',
    description:
      'Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays.',
  })
  parseInput(value: string) {
    return value;
  }

  @Option({
    flags: '-e, --encoding_format <encoding_format>',
    description:
      'Encoding format for the embeddings. Supported formats are float and int.',
  })
  parseEncodingFormat(value: string) {
    return value;
  }

  @Option({
    flags: '-d, --dimensions <dimensions>',
    description:
      'The number of dimensions the resulting output embeddings should have. Only supported in some models.',
  })
  parseDimensionsFormat(value: string) {
    // CLI flag values arrive as strings; coerce to the number declared on
    // EmbeddingCommandOptions.dimensions.
    return parseInt(value, 10);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,30 @@ export class ChatCliUsecases {
}
}

/**
 * Creates an embedding vector representing the input text.
 * Thin pass-through to ChatUsecases.embeddings.
 * @param model Embedding model ID.
 * @param input Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays.
 * @param encoding_format Encoding format for the embeddings. Supported formats are 'float' and 'int'. Defaults to 'float'.
 * @param dimensions The number of dimensions the resulting output embeddings should have. Only supported in some models.
 * @returns Embedding vector.
 */
embeddings(
model: string,
input: string | string[],
encoding_format: string = 'float',
dimensions?: number,
) {
return this.chatUsecases.embeddings(
model,
input,
encoding_format,
dimensions,
);
}

private async getOrCreateNewThread(
modelId: string,
threadId?: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@ import { FileManagerService } from '@/file-manager/file-manager.service';
import { rm } from 'fs/promises';
import { exec } from 'child_process';
import { appPath } from '../utils/app-path';
import { CORTEX_RELEASES_URL, CUDA_DOWNLOAD_URL } from '../../constants/cortex';

@Injectable()
export class InitCliUsecases {
private readonly CORTEX_RELEASES_URL =
'https://api.github.com/repos/janhq/cortex/releases';
private readonly CUDA_DOWNLOAD_URL =
'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz';

constructor(
private readonly httpService: HttpService,
private readonly fileManagerService: FileManagerService,
Expand All @@ -30,7 +26,7 @@ export class InitCliUsecases {
): Promise<any> => {
const res = await firstValueFrom(
this.httpService.get(
this.CORTEX_RELEASES_URL + `${version === 'latest' ? '/latest' : ''}`,
CORTEX_RELEASES_URL + `${version === 'latest' ? '/latest' : ''}`,
{
headers: {
'X-GitHub-Api-Version': '2022-11-28',
Expand Down Expand Up @@ -182,7 +178,7 @@ export class InitCliUsecases {
const platform = process.platform === 'win32' ? 'windows' : 'linux';

const dataFolderPath = await this.fileManagerService.getDataFolderPath();
const url = this.CUDA_DOWNLOAD_URL.replace(
const url = CUDA_DOWNLOAD_URL.replace(
'<version>',
options.cudaVersion === '11' ? '11.7' : '12.0',
).replace('<platform>', platform);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import {
OPEN_CHAT_3_5_JINJA,
ZEPHYR,
ZEPHYR_JINJA,
} from '../constants/prompt-constants';
} from '../../constants/prompt-constants';
import { ModelTokenizer } from '../types/model-tokenizer.interface';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
Expand All @@ -29,6 +29,12 @@ import { join, basename } from 'path';
import { load } from 'js-yaml';
import { existsSync, readFileSync } from 'fs';
import { isLocalModel, normalizeModelId } from '../utils/normalize-model-id';
import {
HUGGING_FACE_DOWNLOAD_FILE_MAIN_URL,
HUGGING_FACE_REPO_MODEL_API_URL,
HUGGING_FACE_REPO_URL,
HUGGING_FACE_TREE_REF_URL,
} from '../../constants/huggingface';

@Injectable()
export class ModelsCliUsecases {
Expand Down Expand Up @@ -68,9 +74,7 @@ export class ModelsCliUsecases {
* @param modelId
*/
async stopModel(modelId: string): Promise<void> {
  // Delegate directly to ModelsUsecases; the trailing .then() discards the
  // result so the returned promise resolves to void.
  return this.modelsUsecases.stopModel(modelId).then();
}

/**
Expand Down Expand Up @@ -126,17 +130,13 @@ export class ModelsCliUsecases {
* @param modelId
*/
async pullModel(modelId: string) {
modelId = /[:/]/.test(modelId) ? modelId : `${modelId}:default`;

const existingModel = await this.modelsUsecases.findOne(modelId);
if (isLocalModel(existingModel?.files)) {
console.error('Model already exists');
process.exit(1);
}

if (/[:/]/.test(modelId)) {
await this.pullHuggingFaceModel(modelId);
}
await this.pullHuggingFaceModel(modelId);
const bar = new SingleBar({}, Presets.shades_classic);
bar.start(100, 0);
const callback = (progress: number) => {
Expand All @@ -152,7 +152,10 @@ export class ModelsCliUsecases {
normalizeModelId(modelId),
basename((model?.files as string[])[0]),
);
await this.modelsUsecases.update(modelId, { files: [fileUrl] });
await this.modelsUsecases.update(modelId, {
files: [fileUrl],
name: modelId.replace(':default', ''),
});
} catch (err) {
bar.stop();
throw err;
Expand Down Expand Up @@ -291,7 +294,7 @@ export class ModelsCliUsecases {
*/
private async fetchJanRepoData(modelId: string) {
const repo = modelId.split(':')[0];
const tree = modelId.split(':')[1];
const tree = modelId.split(':')[1] ?? 'default';
const url = this.getRepoModelsUrl(`janhq/${repo}`, tree);
const res = await fetch(url);
const response:
Expand All @@ -310,7 +313,7 @@ export class ModelsCliUsecases {
? response.map((e) => {
return {
rfilename: e.path,
downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`,
downloadUrl: HUGGING_FACE_TREE_REF_URL(repo, tree, e.path),
fileSize: e.size ?? 0,
};
})
Expand Down Expand Up @@ -367,7 +370,11 @@ export class ModelsCliUsecases {
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0);

for (let i = 0; i < data.siblings.length; i++) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${data.siblings[i].rfilename}`;
const downloadUrl = HUGGING_FACE_DOWNLOAD_FILE_MAIN_URL(
paths[2],
paths[3],
data.siblings[i].rfilename,
);
data.siblings[i].downloadUrl = downloadUrl;
}

Expand All @@ -379,12 +386,12 @@ export class ModelsCliUsecases {
});
});

data.modelUrl = `https://huggingface.co/${paths[2]}/${paths[3]}`;
data.modelUrl = HUGGING_FACE_REPO_URL(paths[2], paths[3]);
return data;
}

// Builds the HuggingFace API URL listing a repo's files, optionally scoped
// to a branch/tag via the `tree` segment.
private getRepoModelsUrl(repoId: string, tree?: string): string {
  return `${HUGGING_FACE_REPO_MODEL_API_URL(repoId)}${tree ? `/tree/${tree}` : ''}`;
}

private async parsePreset(preset?: string): Promise<object> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ModelArtifact } from '@/domain/models/model.interface';

/**
 * Converts a model ID into a filesystem-safe name.
 * Strips the implicit ':default' tag first, then replaces the remaining
 * ':' and '/' separators with '-'.
 */
export const normalizeModelId = (modelId: string): string => {
  return modelId.replace(':default', '').replace(/[:/]/g, '-');
};

export const isLocalModel = (
Expand Down
24 changes: 24 additions & 0 deletions cortex-js/src/infrastructure/constants/cortex.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

// CORTEX CPP
export const CORTEX_CPP_EMBEDDINGS_URL = (
host: string = defaultCortexCppHost,
port: number = defaultCortexCppPort,
) => `http://${host}:${port}/inferences/server/embedding`;

export const CORTEX_CPP_PROCESS_DESTROY_URL = (
host: string = defaultCortexCppHost,
port: number = defaultCortexCppPort,
) => `http://${host}:${port}/processmanager/destroy`;

export const CORTEX_CPP_HEALTH_Z_URL = (
host: string = defaultCortexCppHost,
port: number = defaultCortexCppPort,
) => `http://${host}:${port}/healthz`;

// INITIALIZATION
export const CORTEX_RELEASES_URL =
'https://api.github.com/repos/janhq/cortex/releases';

export const CUDA_DOWNLOAD_URL =
'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz';
Loading

0 comments on commit e1ea4f4

Please sign in to comment.