Skip to content

Commit

Permalink
feat: support pull model by specific fileName
Browse files (browse the repository at this point in the history)
  • Loading branch information
marknguyen1302 authored Jul 9, 2024
1 parent ad545d1 commit bafe1e8
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 7 deletions.
1 change: 1 addition & 0 deletions cortex-js/src/domain/models/model.event.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const ModelLoadingEvents = [
'starting-failed',
'stopping-failed',
'model-downloaded',
'model-downloaded-failed',
'model-deleted',
] as const;
/**
 * Union of the string literals listed in `ModelLoadingEvents`, obtained by
 * indexing the `as const` array's type with `number`.
 */
export type ModelLoadingEvent = (typeof ModelLoadingEvents)[number];
Expand Down
29 changes: 25 additions & 4 deletions cortex-js/src/infrastructure/controllers/models.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import {
Delete,
HttpCode,
UseInterceptors,
Query,
BadRequestException,
} from '@nestjs/common';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
Expand All @@ -26,6 +28,7 @@ import {
} from '@/domain/telemetry/telemetry.interface';
import { TelemetryUsecases } from '@/usecases/telemetry/telemetry.usecases';
import { CommonResponseDto } from '../dtos/common/common-response.dto';
import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface';

@ApiTags('Models')
@Controller('models')
Expand Down Expand Up @@ -117,8 +120,17 @@ export class ModelsController {
})

@Get('download/:modelId(*)')
downloadModel(@Param('modelId') modelId: string) {
this.modelsUsecases.pullModel(modelId, false).then(() => this.telemetryUsecases.addEventToQueue({
downloadModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string) {
this.modelsUsecases.pullModel(modelId, false, (files) => {
return new Promise<HuggingFaceRepoSibling>(async (resolve, reject) => {
const file = files
.find((e) => e.quantization && e.rfilename === fileName)
if(!file) {
return reject(new BadRequestException('File not found'));
}
return resolve(file);
});
}).then(() => this.telemetryUsecases.addEventToQueue({
name: EventName.DOWNLOAD_MODEL,
modelId,
})
Expand Down Expand Up @@ -162,8 +174,17 @@ export class ModelsController {
description: 'The unique identifier of the model.',
})
@Get('pull/:modelId(*)')
pullModel(@Param('modelId') modelId: string) {
this.modelsUsecases.pullModel(modelId).then(() => this.telemetryUsecases.addEventToQueue({
pullModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string) {
this.modelsUsecases.pullModel(modelId, false, (files) => {
return new Promise<HuggingFaceRepoSibling>(async (resolve, reject) => {
const file = files
.find((e) => e.quantization && e.rfilename === fileName)
if(!file) {
return reject(new BadRequestException('File not found'));
}
return resolve(file);
});
}).then(() => this.telemetryUsecases.addEventToQueue({
name: EventName.DOWNLOAD_MODEL,
modelId,
})
Expand Down
17 changes: 15 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,21 @@ export class ModelsUsecases {
await promises.mkdir(modelFolder, { recursive: true }).catch(() => {});

let files = (await fetchJanRepoData(modelId)).siblings;

// HuggingFace GGUF Repo - Only one file is downloaded
if (modelId.includes('/') && selection && files.length) {
try {
files = [await selection(files)];
} catch (e) {
const modelEvent: ModelEvent = {
model: modelId,
event: 'model-downloaded-failed',
metadata: {
error: e.message || e,
},
};
this.eventEmitter.emit('model.event', modelEvent);
}
}

// Start downloading the model
Expand Down Expand Up @@ -420,7 +431,9 @@ export class ModelsUsecases {
const modelEvent: ModelEvent = {
model: modelId,
event: 'model-downloaded',
metadata: {},
metadata: {
...(selection ? { file: [files] } : {}),
},
};
this.eventEmitter.emit('model.event', modelEvent);
},
Expand Down
2 changes: 1 addition & 1 deletion cortex-js/src/utils/model-check.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export const checkModelCompatibility = async (modelId: string, spinner?: ora.Ora
process.exit(1);
}

try{
try {
const version = await getCudaVersion();
const [currentMajor, currentMinor] = version.split('.').map(Number);
const [requiredMajor, requiredMinor] = MIN_CUDA_VERSION.split('.').map(Number);
Expand Down

0 comments on commit bafe1e8

Please sign in to comment.