diff --git a/cortex-js/src/infrastructure/commanders/models/model-pull.command.ts b/cortex-js/src/infrastructure/commanders/models/model-pull.command.ts index c7a751b21..93dfbcc7a 100644 --- a/cortex-js/src/infrastructure/commanders/models/model-pull.command.ts +++ b/cortex-js/src/infrastructure/commanders/models/model-pull.command.ts @@ -43,7 +43,7 @@ export class ModelPullCommand extends CommandRunner { } const modelId = passedParams[0]; - checkModelCompatibility(modelId); + await checkModelCompatibility(modelId); await this.modelsCliUsecases.pullModel(modelId).catch((e: Error) => { if (e instanceof ModelNotFoundException) diff --git a/cortex-js/src/infrastructure/commanders/models/model-start.command.ts b/cortex-js/src/infrastructure/commanders/models/model-start.command.ts index 5cec03037..63c7130c9 100644 --- a/cortex-js/src/infrastructure/commanders/models/model-start.command.ts +++ b/cortex-js/src/infrastructure/commanders/models/model-start.command.ts @@ -65,8 +65,9 @@ export class ModelStartCommand extends CommandRunner { process.exit(1); } - checkModelCompatibility(modelId); + await checkModelCompatibility(modelId); checkingSpinner.succeed('Model found'); + const engine = existingModel.engine || Engines.llamaCPP; // Pull engine if not exist if ( diff --git a/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts b/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts index 5c6e96ffe..0b5598f0c 100644 --- a/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts +++ b/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts @@ -81,7 +81,8 @@ export class RunCommand extends CommandRunner { checkingSpinner.succeed('Model found'); // Check model compatibility on this machine - checkModelCompatibility(modelId); + await checkModelCompatibility(modelId); + const engine = existingModel.engine || Engines.llamaCPP; // Pull engine if not exist if ( diff --git a/cortex-js/src/infrastructure/constants/cortex.ts 
b/cortex-js/src/infrastructure/constants/cortex.ts index 0f7d5b7cc..8b91809f9 100644 --- a/cortex-js/src/infrastructure/constants/cortex.ts +++ b/cortex-js/src/infrastructure/constants/cortex.ts @@ -49,3 +49,5 @@ export const CUDA_DOWNLOAD_URL = 'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz'; export const telemetryServerUrl = 'https://telemetry.jan.ai'; + +export const MIN_CUDA_VERSION = '12.3'; \ No newline at end of file diff --git a/cortex-js/src/utils/cuda.ts b/cortex-js/src/utils/cuda.ts index fc252a20a..c533414f9 100644 --- a/cortex-js/src/utils/cuda.ts +++ b/cortex-js/src/utils/cuda.ts @@ -15,6 +15,7 @@ export type GpuSettingInfo = { * @returns CUDA Version 11 | 12 */ export const cudaVersion = async () => { + let filesCuda12: string[]; let filesCuda11: string[]; let paths: string[]; @@ -71,6 +72,33 @@ export const checkNvidiaGPUExist = (): Promise<boolean> => { }); }; +export const getCudaVersion = (): Promise<string> => { + return new Promise((resolve, reject) => { + // Execute the nvidia-smi command + exec('nvidia-smi', (error, stdout) => { + if (!error) { + const cudaVersionLine = stdout.split('\n').find(line => line.includes('CUDA Version')); + + if (cudaVersionLine) { + // Extract the CUDA version number + const cudaVersionMatch = cudaVersionLine.match(/CUDA Version:\s+(\d+\.\d+)/); + if (cudaVersionMatch) { + const cudaVersion = cudaVersionMatch[1]; + resolve(cudaVersion); + } else { + reject('CUDA Version not found.'); + } + } else { + reject('CUDA Version not found.'); + } + } else { + reject(error); + } + + }); + }); +}; + /** * Get GPU information from the system * @returns GPU information diff --git a/cortex-js/src/utils/model-check.ts b/cortex-js/src/utils/model-check.ts index 73e144295..1606cf111 100644 --- a/cortex-js/src/utils/model-check.ts +++ b/cortex-js/src/utils/model-check.ts @@ -1,11 +1,31 @@ -export const checkModelCompatibility = (modelId: string) => { +import { MIN_CUDA_VERSION } from "@/infrastructure/constants/cortex"; +import { 
getCudaVersion } from "./cuda"; + +export const checkModelCompatibility = async (modelId: string) => { if (modelId.includes('onnx') && process.platform !== 'win32') { console.error('The ONNX engine does not support this OS yet.'); process.exit(1); } - if (modelId.includes('tensorrt-llm') && process.platform === 'darwin') { - console.error('Tensorrt-LLM models are not supported on this OS'); - process.exit(1); + if (modelId.includes('tensorrt-llm') ) { + if(process.platform === 'darwin'){ + console.error('Tensorrt-LLM models are not supported on this OS'); + process.exit(1); + } + + try{ + const version = await getCudaVersion(); + const [currentMajor, currentMinor] = version.split('.').map(Number); + const [requiredMajor, requiredMinor] = MIN_CUDA_VERSION.split('.').map(Number); + const isMatchRequired = currentMajor > requiredMajor || (currentMajor === requiredMajor && currentMinor >= requiredMinor); + if (!isMatchRequired) { + console.error(`CUDA version ${version} is not compatible with TensorRT-LLM models. Required version: ${MIN_CUDA_VERSION}`); + process.exit(1); + } + } catch (e) { + console.error(e.message ?? e); + process.exit(1); + } + } };