feature: support local GGUF model pull

louis-menlo committed Jul 30, 2024
1 parent 84216df commit 2012869
Showing 9 changed files with 90 additions and 44 deletions.
5 changes: 2 additions & 3 deletions cortex-js/package.json
@@ -39,8 +39,6 @@
},
"dependencies": {
"@cortexso/cortex.js": "^0.1.5",
"@huggingface/gguf": "^0.1.5",
"@huggingface/hub": "^0.15.1",
"@nestjs/axios": "^3.0.2",
"@nestjs/common": "^10.0.0",
"@nestjs/config": "^3.2.2",
@@ -59,6 +57,7 @@
"cortex-cpp": "0.4.34",
"cpu-instructions": "^0.0.11",
"decompress": "^4.2.1",
"hyllama": "^0.2.2",
"js-yaml": "^4.1.0",
"nest-commander": "^3.13.0",
"ora": "5.4.1",
@@ -94,10 +93,10 @@
"@yao-pkg/pkg": "^5.12.0",
"cpx": "^1.5.0",
"env-cmd": "10.1.0",
"eslint": "8.57.0",
"eslint-config-prettier": "9.1.0",
"eslint-plugin-import": "2.29.1",
"eslint-plugin-prettier": "5.2.1",
"eslint": "8.57.0",
"hanbi": "^1.0.3",
"is-primitive": "^3.0.1",
"jest": "^29.5.0",
@@ -19,6 +19,7 @@ import { downloadProgress } from '@/utils/download-progress';
import { CortexClient } from '../services/cortex.client';
import { DownloadType } from '@/domain/models/download.interface';
import ora from 'ora';
import { isLocalFile } from '@/utils/urls';

@SubCommand({
name: 'pull',
@@ -61,9 +62,8 @@ export class ModelPullCommand extends BaseCommand {
exit(1);
});

ora().succeed('Model downloaded');

await downloadProgress(this.cortex, modelId);
ora().succeed('Model downloaded');

const existingModel = await this.cortex.models.retrieve(modelId);
const engine = existingModel?.engine || Engines.llamaCPP;
9 changes: 9 additions & 0 deletions cortex-js/src/infrastructure/commanders/run.command.ts
@@ -13,6 +13,8 @@ import { ChatClient } from './services/chat-client';
import { downloadProgress } from '@/utils/download-progress';
import { CortexClient } from './services/cortex.client';
import { DownloadType } from '@/domain/models/download.interface';
import { isLocalFile } from '@/utils/urls';
import { parse } from 'node:path';

type RunOptions = {
threadId?: string;
@@ -71,6 +73,12 @@ export class RunCommand extends BaseCommand {
await downloadProgress(this.cortex, modelId);
checkingSpinner.succeed('Model downloaded');

// Update to persisted modelId
// TODO: Should be retrieved from the request
if (isLocalFile(modelId)) {
modelId = parse(modelId).name;
}

// Second check if model is available
existingModel = await this.cortex.models.retrieve(modelId);
if (!existingModel) {
@@ -93,6 +101,7 @@
}

const startingSpinner = ora('Loading model...').start();

return this.cortex.models
.start(modelId, await this.fileService.getPreset(options.preset))
.then(() => {
@@ -2,4 +2,5 @@ export interface ModelMetadata {
stopWord?: string;
promptTemplate: string;
version: number;
  name?: string;
}
78 changes: 46 additions & 32 deletions cortex-js/src/usecases/models/models.usecases.ts
@@ -4,7 +4,7 @@ import { UpdateModelDto } from '@/infrastructure/dtos/models/update-model.dto';
import { BadRequestException, Injectable } from '@nestjs/common';
import { Model, ModelSettingParams } from '@/domain/models/model.interface';
import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception';
import { basename, join } from 'path';
import { basename, join, parse } from 'path';
import { promises, existsSync, mkdirSync, readFileSync, rmSync } from 'fs';
import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto';
import { ExtensionRepository } from '@/domain/repositories/extension.interface';
@@ -17,7 +17,6 @@ import { TelemetrySource } from '@/domain/telemetry/telemetry.interface';
import { ModelRepository } from '@/domain/repositories/model.interface';
import { ModelParameterParser } from '@/utils/model-parameter.parser';
import {
HuggingFaceModelVersion,
HuggingFaceRepoData,
HuggingFaceRepoSibling,
} from '@/domain/models/huggingface.interface';
@@ -26,7 +25,10 @@
fetchJanRepoData,
getHFModelMetadata,
} from '@/utils/huggingface';
import { DownloadType } from '@/domain/models/download.interface';
import {
DownloadStatus,
DownloadType,
} from '@/domain/models/download.interface';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { ModelEvent, ModelId, ModelStatus } from '@/domain/models/model.event';
import { DownloadManagerService } from '@/infrastructure/services/download-manager/download-manager.service';
@@ -35,6 +37,7 @@ import { Engines } from '@/infrastructure/commanders/types/engine.interface';
import { load } from 'js-yaml';
import { llamaModelFile } from '@/utils/app-path';
import { CortexUsecases } from '../cortex/cortex.usecases';
import { isLocalFile } from '@/utils/urls';

@Injectable()
export class ModelsUsecases {
@@ -127,7 +130,9 @@ export class ModelsUsecases {
)) as EngineExtension | undefined;

if (engine) {
await engine.unloadModel(id, model.engine || Engines.llamaCPP).catch(() => {}); // Silent fail
await engine
.unloadModel(id, model.engine || Engines.llamaCPP)
.catch(() => {}); // Silent fail
}
return this.modelRepository
.remove(id)
@@ -174,7 +179,7 @@
}

// Attempt to start cortex
await this.cortexUsecases.startCortex()
await this.cortexUsecases.startCortex();

const loadingModelSpinner = ora('Loading model...').start();
// update states and emitting event
@@ -341,10 +346,26 @@
) {
const modelId = persistedModelId ?? originModelId;
const existingModel = await this.findOne(modelId);

if (isLocalModel(existingModel?.files)) {
throw new BadRequestException('Model already exists');
}

// Pull a local model file
if (isLocalFile(originModelId)) {
await this.populateHuggingFaceModel(originModelId, persistedModelId);
this.eventEmitter.emit('download.event', [
{
id: modelId,
type: DownloadType.Model,
status: DownloadStatus.Downloaded,
progress: 100,
children: [],
},
]);
return;
}

const modelsContainerDir = await this.fileManagerService.getModelsPath();

if (!existsSync(modelsContainerDir)) {
Expand Down Expand Up @@ -422,22 +443,18 @@ export class ModelsUsecases {
model.model = modelId;
if (!(await this.findOne(modelId))) await this.create(model);
} else {
await this.populateHuggingFaceModel(modelId, files[0]);
const model = await this.findOne(modelId);
if (model) {
const fileUrl = join(
await this.fileManagerService.getModelsPath(),
normalizeModelId(modelId),
basename(
files.find((e) => e.rfilename.endsWith('.gguf'))?.rfilename ??
files[0].rfilename,
),
);
await this.update(modelId, {
files: [fileUrl],
name: modelId.replace(':main', ''),
});
}
const fileUrl = join(
await this.fileManagerService.getModelsPath(),
normalizeModelId(modelId),
basename(
files.find((e) => e.rfilename.endsWith('.gguf'))?.rfilename ??
files[0].rfilename,
),
);
await this.populateHuggingFaceModel(
fileUrl,
modelId.replace(':main', ''),
);
}
uploadModelMetadataSpiner.succeed('Model metadata updated');
const modelEvent: ModelEvent = {
@@ -458,21 +475,18 @@
* It could be a model from Jan's repo or other authors
* @param modelId HuggingFace model id. e.g. "janhq/llama-3 or llama3:7b"
*/
async populateHuggingFaceModel(
modelId: string,
modelVersion: HuggingFaceModelVersion,
) {
if (!modelVersion) throw 'No expected quantization found';

const tokenizer = await getHFModelMetadata(modelVersion.downloadUrl!);
async populateHuggingFaceModel(ggufUrl: string, overridenId?: string) {
const metadata = await getHFModelMetadata(ggufUrl);

const stopWords: string[] = tokenizer?.stopWord ? [tokenizer.stopWord] : [];
const stopWords: string[] = metadata?.stopWord ? [metadata.stopWord] : [];

const modelId =
overridenId ?? (isLocalFile(ggufUrl) ? parse(ggufUrl).name : ggufUrl);
const model: CreateModelDto = {
files: [modelVersion.downloadUrl ?? ''],
files: [ggufUrl],
model: modelId,
name: modelId,
prompt_template: tokenizer?.promptTemplate,
name: metadata?.name ?? modelId,
prompt_template: metadata?.promptTemplate,
stop: stopWords,

// Default Inference Params
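
Worth noting in the refactor above: for a local pull, populateHuggingFaceModel falls back to deriving the model id from the file path itself via parse(ggufUrl).name, i.e. the basename minus its final extension, the same derivation run.command.ts now applies. A minimal sketch (the path is a hypothetical example):

import { parse } from 'node:path';

const ggufUrl = '/tmp/models/llama-3-8b.Q4_K_M.gguf'; // hypothetical local file
const modelId = parse(ggufUrl).name; // => 'llama-3-8b.Q4_K_M'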
3 changes: 3 additions & 0 deletions cortex-js/src/utils/download-progress.ts
@@ -4,6 +4,9 @@ import { exit, stdin, stdout } from 'node:process';
import { DownloadState, DownloadType } from "@/domain/models/download.interface";

export const downloadProgress = async (cortex: Cortex, downloadId?: string, downloadType?: DownloadType) => {
// Do not update on local file symlink
if (downloadId && isLocalFile(downloadId)) return;

const response = await cortex.events.downloadEvent();

const rl = require('readline').createInterface({
19 changes: 14 additions & 5 deletions cortex-js/src/utils/huggingface.ts
@@ -21,9 +21,9 @@ import {
ZEPHYR,
ZEPHYR_JINJA,
} from '@/infrastructure/constants/prompt-constants';
import { gguf } from '@huggingface/gguf';
import axios from 'axios';
import { parseModelHubEngineBranch } from './normalize-model-id';
import { closeSync, openSync, readSync } from 'fs';

// TODO: move this to somewhere else, should be reused by API as well. Maybe in a separate service / provider?
export function guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
@@ -209,20 +209,29 @@ export async function getHFModelMetadata(
ggufUrl: string,
): Promise<ModelMetadata | undefined> {
try {
const { metadata } = await gguf(ggufUrl);
// @ts-expect-error "tokenizer.ggml.eos_token_id"
let metadata: any;
const { ggufMetadata } = await import('hyllama');
// Read first 10mb of gguf file
const fd = openSync(ggufUrl, 'r');
const buffer = new Uint8Array(10_000_000);
readSync(fd, buffer, 0, 10_000_000, 0);
closeSync(fd);

// Parse metadata and tensor info
({ metadata } = ggufMetadata(buffer.buffer));

const index = metadata['tokenizer.ggml.eos_token_id'];
// @ts-expect-error "tokenizer.ggml.eos_token_id"
const hfChatTemplate = metadata['tokenizer.chat_template'];
const promptTemplate = guessPromptTemplateFromHuggingFace(hfChatTemplate);
// @ts-expect-error "tokenizer.ggml.tokens"
const stopWord: string = metadata['tokenizer.ggml.tokens'][index] ?? '';
const name = metadata['general.name'];

const version: number = metadata['version'];
return {
stopWord,
promptTemplate,
version,
name,
};
} catch (err) {
console.log('Failed to get model metadata:', err.message);
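
For reference, the hyllama-based read above works standalone as well. A minimal sketch, assuming the same ggufMetadata(ArrayBuffer) API the diff uses and that the GGUF header (which holds all metadata) fits within the first 10 MB:

import { closeSync, openSync, readSync } from 'node:fs';

export async function readGgufMetadata(path: string): Promise<Record<string, any>> {
  // hyllama is loaded dynamically, mirroring the diff above
  const { ggufMetadata } = await import('hyllama');

  // Read only the leading bytes of the file
  const fd = openSync(path, 'r');
  const buffer = new Uint8Array(10_000_000);
  readSync(fd, buffer, 0, 10_000_000, 0);
  closeSync(fd);

  const { metadata } = ggufMetadata(buffer.buffer);
  return metadata;
}

// Usage, e.g.: readGgufMetadata('/tmp/model.gguf').then((m) => console.log(m['general.name']));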
11 changes: 11 additions & 0 deletions cortex-js/src/utils/urls.ts
@@ -1,3 +1,5 @@
import { isAbsolute } from 'path';

/**
* Check if a string is a valid URL.
* @param input - The string to check.
@@ -12,3 +14,12 @@ export function isValidUrl(input: string | undefined): boolean {
return false;
}
}

/**
 * Check if a path is a local file path.
 * @param path - The path to check.
 * @returns True if the path is an absolute local path rather than a URL.
*/
export const isLocalFile = (path: string): boolean => {
return !/^(http|https):\/\/[^/]+\/.*/.test(path) && isAbsolute(path);
};
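
A quick sketch of what this predicate accepts and rejects, per the regex plus isAbsolute check above (paths are illustrative):

isLocalFile('/models/llama-3.gguf');        // true: absolute path, not an http(s) URL
isLocalFile('https://huggingface.co/a/b');  // false: matches the URL pattern
isLocalFile('models/llama-3.gguf');         // false: relative paths are rejected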
4 changes: 2 additions & 2 deletions cortex-js/tsconfig.json
@@ -1,7 +1,7 @@
{
"compilerOptions": {
"module": "commonjs",
"moduleResolution": "node",
"module": "node16",
"moduleResolution": "node16",
"declaration": true,
"removeComments": true,
"emitDecoratorMetadata": true,