Skip to content

Commit

Permalink
chore: update model start params DTO
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Jun 10, 2024
1 parent c97a46e commit 85c4f69
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 36 deletions.
16 changes: 1 addition & 15 deletions cortex-js/src/domain/models/model.interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export interface Model {
ngl?: number;

/**
* The number of parallel operations. Only set when enable continuous batching.
* Number of parallel sequences to decode
*/
n_parallel?: number;

Expand All @@ -96,13 +96,6 @@ export interface Model {
engine?: string;
}

/**
 * Descriptive metadata attached to a model.
 */
export interface ModelMetadata {
// Author of the model (required).
author: string;
// List of tags; presumably used for categorisation/search — confirm against consumers.
tags: string[];
// Model size; the unit is not stated here (presumably bytes — TODO confirm).
size: number;
// Optional cover; presumably an image URL or path — TODO confirm.
cover?: string;
}

/**
* The available model settings.
*/
Expand Down Expand Up @@ -140,10 +133,3 @@ export interface ModelRuntimeParams {
presence_penalty?: number;
engine?: string;
}

/**
* Represents the model initialization error.
*
* Intersection type: every field of `Model` plus the `Error` describing the failure.
* NOTE(review): presumably produced when a model fails to start — confirm against callers.
*/
export type ModelInitFailed = Model & {
// The error captured for this model.
error: Error;
};
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import {
import { exit } from 'node:process';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { isLocalModel } from '../utils/normalize-model-id';

type ModelStartOptions = {
attach: boolean;
Expand Down Expand Up @@ -52,9 +51,7 @@ export class ModelStartCommand extends CommandRunner {
}

modelInquiry = async () => {
const models = (await this.modelsCliUsecases.listAllModels()).filter(
(model) => isLocalModel(model.files),
);
const models = await this.modelsCliUsecases.listAllModels();
if (!models.length) throw 'No models found';
const { model } = await this.inquirerService.inquirer.prompt({
type: 'list',
Expand Down
10 changes: 5 additions & 5 deletions cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import {
} from 'nest-commander';
import { exit } from 'node:process';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from '@/infrastructure/constants/cortex';
import {
defaultCortexCppHost,
defaultCortexCppPort,
} from '@/infrastructure/constants/cortex';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
import { isLocalModel } from '../utils/normalize-model-id';
import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception';

type RunOptions = {
Expand Down Expand Up @@ -77,9 +79,7 @@ export class RunCommand extends CommandRunner {
}

modelInquiry = async () => {
const models = (await this.modelsCliUsecases.listAllModels()).filter(
(model) => isLocalModel(model.files),
);
const models = await this.modelsCliUsecases.listAllModels();
if (!models.length) throw 'No models found';
const { model } = await this.inquirerService.inquirer.prompt({
type: 'list',
Expand Down
8 changes: 6 additions & 2 deletions cortex-js/src/infrastructure/controllers/models.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { ApiOperation, ApiParam, ApiTags, ApiResponse } from '@nestjs/swagger';
import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto';
import { TransformInterceptor } from '../interceptors/transform.interceptor';
import { CortexUsecases } from '@/usecases/cortex/cortex.usecases';
import { ModelSettingsDto } from '../dtos/models/model-settings.dto';

@ApiTags('Models')
@Controller('models')
Expand Down Expand Up @@ -61,10 +62,13 @@ export class ModelsController {
description: 'The unique identifier of the model.',
})
@Post(':modelId(*)/start')
startModel(@Param('modelId') modelId: string, @Body() model: ModelDto) {
startModel(
@Param('modelId') modelId: string,
@Body() params: ModelSettingsDto,
) {
return this.cortexUsecases
.startCortex()
.then(() => this.modelsUsecases.startModel(modelId, model));
.then(() => this.modelsUsecases.startModel(modelId, params));
}

@HttpCode(200)
Expand Down
18 changes: 15 additions & 3 deletions cortex-js/src/infrastructure/dtos/models/create-model.dto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
IsNumber,
IsOptional,
IsString,
Min,
} from 'class-validator';
import { Model } from '@/domain/models/model.interface';
import { ModelArtifactDto } from './model-artifact.dto';
Expand Down Expand Up @@ -50,6 +51,7 @@ export class CreateModelDto implements Partial<Model> {
@ApiProperty({
description:
'Sets the upper limit on the number of tokens the model can generate in a single output.',
example: 4096,
})
@IsOptional()
@IsNumber()
Expand Down Expand Up @@ -97,30 +99,40 @@ export class CreateModelDto implements Partial<Model> {
@ApiProperty({
description:
'Sets the maximum input the model can use to generate a response, it varies with the model used.',
example: 4096,
})
@IsOptional()
@IsNumber()
ctx_len?: number;

@ApiProperty({ description: 'Determines GPU layer usage.' })
@ApiProperty({ description: 'Determines GPU layer usage.', example: 32 })
@IsOptional()
@IsNumber()
ngl?: number;

@ApiProperty({ description: 'Number of parallel processing units to use.' })
@ApiProperty({
description: 'Number of parallel processing units to use.',
example: 1,
})
@IsOptional()
@IsNumber()
@Min(1)
n_parallel?: number;

@ApiProperty({
description:
'Determines CPU inference threads, limited by hardware and OS. ',
example: 10,
})
@IsOptional()
@IsNumber()
@Min(1)
cpu_threads?: number;

@ApiProperty({ description: 'The engine used to run the model.' })
@ApiProperty({
description: 'The engine used to run the model.',
example: 'cortex.llamacpp',
})
@IsOptional()
@IsString()
engine?: string;
Expand Down
56 changes: 56 additions & 0 deletions cortex-js/src/infrastructure/dtos/models/model-settings.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { ModelSettingParams } from '@/domain/models/model.interface';
import { ApiProperty } from '@nestjs/swagger';
import {
  IsArray,
  IsNumber,
  IsOptional,
  IsString,
  Min,
} from 'class-validator';

/**
 * Request body accepted when starting a model.
 *
 * Every field is optional; each one is validated only when it is present
 * (`@IsOptional()` skips the remaining validators for missing values).
 * The shape mirrors `ModelSettingParams` from the domain layer.
 */
export class ModelSettingsDto implements ModelSettingParams {
  // Prompt Settings

  /** Template used to assemble the prompt sent to the model. */
  @ApiProperty({
    example: 'system\n{system_message}\nuser\n{prompt}\nassistant',
    description:
      "A predefined text or framework that guides the AI model's response generation.",
  })
  @IsOptional()
  @IsString()
  prompt_template?: string;

  /** Stop sequences; must be an array whose elements are all strings. */
  @ApiProperty({
    type: [String],
    example: [],
    description:
      'Defines specific tokens or phrases that signal the model to stop producing further output.',
  })
  @IsOptional()
  @IsArray()
  @IsString({ each: true })
  stop?: string[];

  // Engine Settings

  /**
   * GPU layer setting.
   * The documented example matches the one used by `CreateModelDto`.
   */
  @ApiProperty({ description: 'Determines GPU layer usage.', example: 32 })
  @IsOptional()
  @IsNumber()
  ngl?: number;

  /** Context length for model operations. */
  @ApiProperty({
    description:
      'The context length for model operations varies; the maximum depends on the specific model used.',
    example: 4096,
  })
  @IsOptional()
  @IsNumber()
  ctx_len?: number;

  /** CPU inference thread count; must be at least 1 when provided. */
  @ApiProperty({
    description:
      'Determines CPU inference threads, limited by hardware and OS. ',
    example: 10,
  })
  @IsOptional()
  @IsNumber()
  @Min(1)
  cpu_threads?: number;

  /** Engine identifier; validated as a string, like `CreateModelDto.engine`. */
  @ApiProperty({
    example: 'cortex.llamacpp',
    description: 'The engine to use.',
  })
  @IsOptional()
  @IsString()
  engine?: string;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import {
writeFileSync,
} from 'fs';
import { load, dump } from 'js-yaml';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';
import {
isLocalModel,
normalizeModelId,
} from '@/infrastructure/commanders/utils/normalize-model-id';

@Injectable()
export class ModelRepositoryImpl implements ModelRepository {
Expand Down Expand Up @@ -58,7 +61,9 @@ export class ModelRepositoryImpl implements ModelRepository {
* @returns the loaded models that pass the `isLocalModel` check
*/
findAll(): Promise<Model[]> {
return this.loadModels();
return this.loadModels().then((res) =>
res.filter((model) => isLocalModel(model.files)),
);
}
/**
* Find one model by id
Expand Down
5 changes: 4 additions & 1 deletion cortex-js/src/usecases/cortex/cortex.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ import { ChildProcess, spawn } from 'child_process';
import { join } from 'path';
import { CortexOperationSuccessfullyDto } from '@/infrastructure/dtos/cortex/cortex-operation-successfully.dto';
import { HttpService } from '@nestjs/axios';
import { defaultCortexCppHost, defaultCortexCppPort } from '@/infrastructure/constants/cortex';
import {
defaultCortexCppHost,
defaultCortexCppPort,
} from '@/infrastructure/constants/cortex';
import { existsSync } from 'node:fs';
import { firstValueFrom } from 'rxjs';
import { FileManagerService } from '@/file-manager/file-manager.service';
Expand Down
3 changes: 1 addition & 2 deletions cortex-js/src/usecases/messages/messages.module.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { Module } from '@nestjs/common';
import { MessagesUsecases } from './messages.usecases';
import { MessagesController } from '@/infrastructure/controllers/messages.controller';
import { DatabaseModule } from '@/infrastructure/database/database.module';

@Module({
imports: [DatabaseModule],
controllers: [MessagesController],
controllers: [],
providers: [MessagesUsecases],
exports: [MessagesUsecases],
})
Expand Down
10 changes: 8 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ export class ModelsUsecases {

return this.modelRepository
.remove(id)
.then(() => rmdirSync(modelFolder, { recursive: true }))
.then(
() =>
existsSync(modelFolder) &&
rmdirSync(modelFolder, { recursive: true }),
)
.then(() => {
return {
message: 'Model removed successfully',
Expand Down Expand Up @@ -100,7 +104,9 @@ export class ModelsUsecases {
// Default settings
ctx_len: 4096,
ngl: 100,
...(Array.isArray(model?.files) &&
//TODO: Utils for model file retrieval
...(model?.files &&
Array.isArray(model.files) &&
!('llama_model_path' in model) && {
llama_model_path: (model.files as string[])[0],
}),
Expand Down

0 comments on commit 85c4f69

Please sign in to comment.