diff --git a/cortex-js/src/domain/models/model.event.ts b/cortex-js/src/domain/models/model.event.ts index 7d2177911..a28a9afa4 100644 --- a/cortex-js/src/domain/models/model.event.ts +++ b/cortex-js/src/domain/models/model.event.ts @@ -8,6 +8,7 @@ const ModelLoadingEvents = [ 'starting-failed', 'stopping-failed', 'model-downloaded', + 'model-downloaded-failed', 'model-deleted', ] as const; export type ModelLoadingEvent = (typeof ModelLoadingEvents)[number]; diff --git a/cortex-js/src/infrastructure/controllers/models.controller.ts b/cortex-js/src/infrastructure/controllers/models.controller.ts index 11fcd4f7d..c9b1e4ca5 100644 --- a/cortex-js/src/infrastructure/controllers/models.controller.ts +++ b/cortex-js/src/infrastructure/controllers/models.controller.ts @@ -8,6 +8,8 @@ import { Delete, HttpCode, UseInterceptors, + Query, + BadRequestException, } from '@nestjs/common'; import { ModelsUsecases } from '@/usecases/models/models.usecases'; import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto'; @@ -26,6 +28,7 @@ import { } from '@/domain/telemetry/telemetry.interface'; import { TelemetryUsecases } from '@/usecases/telemetry/telemetry.usecases'; import { CommonResponseDto } from '../dtos/common/common-response.dto'; +import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface'; @ApiTags('Models') @Controller('models') @@ -117,8 +120,17 @@ export class ModelsController { }) @Get('download/:modelId(*)') - downloadModel(@Param('modelId') modelId: string) { - this.modelsUsecases.pullModel(modelId, false).then(() => this.telemetryUsecases.addEventToQueue({ + downloadModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string) { + this.modelsUsecases.pullModel(modelId, false, (files) => { + return new Promise(async (resolve, reject) => { + const file = files + .find((e) => e.quantization && e.rfilename === fileName) + if(!file) { + return reject(new BadRequestException('File not found')); + } + return resolve(file); + }); + }).then(() => this.telemetryUsecases.addEventToQueue({ name: EventName.DOWNLOAD_MODEL, modelId, }) @@ -162,8 +174,17 @@ export class ModelsController { description: 'The unique identifier of the model.', }) @Get('pull/:modelId(*)') - pullModel(@Param('modelId') modelId: string) { - this.modelsUsecases.pullModel(modelId).then(() => this.telemetryUsecases.addEventToQueue({ + pullModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string) { + this.modelsUsecases.pullModel(modelId, false, (files) => { + return new Promise(async (resolve, reject) => { + const file = files + .find((e) => e.quantization && e.rfilename === fileName) + if(!file) { + return reject(new BadRequestException('File not found')); + } + return resolve(file); + }); + }).then(() => this.telemetryUsecases.addEventToQueue({ name: EventName.DOWNLOAD_MODEL, modelId, }) diff --git a/cortex-js/src/usecases/models/models.usecases.ts b/cortex-js/src/usecases/models/models.usecases.ts index 471ad47b8..52375b7f9 100644 --- a/cortex-js/src/usecases/models/models.usecases.ts +++ b/cortex-js/src/usecases/models/models.usecases.ts @@ -342,10 +342,21 @@ export class ModelsUsecases { await promises.mkdir(modelFolder, { recursive: true }).catch(() => {}); let files = (await fetchJanRepoData(modelId)).siblings; - + // HuggingFace GGUF Repo - Only one file is downloaded if (modelId.includes('/') && selection && files.length) { + try { files = [await selection(files)]; + } catch (e) { + const modelEvent: ModelEvent = { + model: modelId, + event: 'model-downloaded-failed', + metadata: { + error: e.message || e, + }, + }; + this.eventEmitter.emit('model.event', modelEvent); + } } // Start downloading the model @@ -420,7 +431,9 @@ export class ModelsUsecases { const modelEvent: ModelEvent = { model: modelId, event: 'model-downloaded', - metadata: {}, + metadata: { + ...(selection ? { file: [files] } : {}), + }, }; this.eventEmitter.emit('model.event', modelEvent); }, diff --git a/cortex-js/src/utils/model-check.ts b/cortex-js/src/utils/model-check.ts index 59b0e2842..53ef84d47 100644 --- a/cortex-js/src/utils/model-check.ts +++ b/cortex-js/src/utils/model-check.ts @@ -21,7 +21,7 @@ export const checkModelCompatibility = async (modelId: string, spinner?: ora.Ora process.exit(1); } - try{ + try { const version = await getCudaVersion(); const [currentMajor, currentMinor] = version.split('.').map(Number); const [requiredMajor, requiredMinor] = MIN_CUDA_VERSION.split('.').map(Number);