Skip to content

Commit

Permalink
chore: check model not exist before pulling
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Jun 5, 2024
1 parent dc1c170 commit f5d8688
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ export class ModelGetCommand extends CommandRunner {
exit(1);
}

const models = await this.modelsCliUsecases.getModel(input[0]);
console.log(models);
const model = await this.modelsCliUsecases.getModel(input[0]);
if (!model) console.error('Model not found');
else console.log(model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { CommandRunner, InquirerService, SubCommand } from 'nest-commander';
import { exit } from 'node:process';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
import { RepoDesignation, listFiles } from '@huggingface/hub';
import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception';

@SubCommand({
name: 'pull',
Expand All @@ -28,9 +29,16 @@ export class ModelPullCommand extends CommandRunner {
? undefined
: await this.tryToGetBranches(input[0]);

await this.modelsCliUsecases.pullModel(
!branches ? input[0] : await this.handleJanHqModel(input[0], branches),
);
await this.modelsCliUsecases
.pullModel(
!branches ? input[0] : await this.handleJanHqModel(input[0], branches),
)
.catch((e: Error) => {
if (e instanceof ModelNotFoundException)
console.error('Model does not exist.');
else console.error(e);
exit(1);
});

console.log('\nDownload complete!');
exit(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ export class ModelStartCommand extends CommandRunner {
}
}

const existingModel = await this.modelsCliUsecases.getModel(modelId);
if (
!existingModel ||
!Array.isArray(existingModel.files) ||
/^(http|https):\/\/[^/]+\/.*/.test(existingModel.files[0])
) {
console.error('Model is not available. Please pull the model first.');
process.exit(1);
}

await this.cortexUsecases
.startCortex(options.attach)
.then(() => this.modelsCliUsecases.startModel(modelId, options.preset))
Expand All @@ -41,7 +51,11 @@ export class ModelStartCommand extends CommandRunner {
}

modelInquiry = async () => {
const models = await this.modelsCliUsecases.listAllModels();
const models = (await this.modelsCliUsecases.listAllModels()).filter(
(model) =>
Array.isArray(model.files) &&
!/^(http|https):\/\/[^/]+\/.*/.test(model.files[0]),
);
if (!models.length) throw 'No models found';
const { model } = await this.inquirerService.inquirer.prompt({
type: 'list',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ export class ModelsCliUsecases {
* @param modelId
* @returns
*/
private async getModelOrStop(modelId: string): Promise<Model> {
async getModelOrStop(modelId: string): Promise<Model> {
const model = await this.modelsUsecases.findOne(modelId);
if (!model) {
console.debug('Model not found');
Expand All @@ -127,9 +127,8 @@ export class ModelsCliUsecases {
* @param modelId
* @returns
*/
async getModel(modelId: string): Promise<Model> {
const model = await this.getModelOrStop(modelId);
return model;
async getModel(modelId: string): Promise<Model | null> {
return this.modelsUsecases.findOne(modelId);
}

/**
Expand All @@ -147,7 +146,12 @@ export class ModelsCliUsecases {
* @param modelId
*/
async pullModel(modelId: string) {
if (await this.modelsUsecases.findOne(modelId)) {
const existingModel = await this.modelsUsecases.findOne(modelId);
if (
existingModel &&
Array.isArray(existingModel.files) &&
!/^(http|https):\/\/[^/]+\/.*/.test(existingModel.files[0])
) {
console.error('Model already exists');
process.exit(1);
}
Expand All @@ -161,15 +165,20 @@ export class ModelsCliUsecases {
bar.update(progress);
};

await this.modelsUsecases.downloadModel(modelId, callback);
try {
await this.modelsUsecases.downloadModel(modelId, callback);

const model = await this.modelsUsecases.findOne(modelId);
const fileUrl = join(
await this.fileService.getModelsPath(),
normalizeModelId(modelId),
basename((model?.files as string[])[0]),
);
await this.modelsUsecases.update(modelId, { files: [fileUrl] });
const model = await this.modelsUsecases.findOne(modelId);
const fileUrl = join(
await this.fileService.getModelsPath(),
normalizeModelId(modelId),
basename((model?.files as string[])[0]),
);
await this.modelsUsecases.update(modelId, { files: [fileUrl] });
} catch (err) {
bar.stop();
throw err;
}
}

private async getHFModelTokenizer(
Expand Down

0 comments on commit f5d8688

Please sign in to comment.