// ---------------------------------------------------------------------------
// cortex-js/src/utils/cuda.ts (additions)
// ---------------------------------------------------------------------------

/**
 * One GPU as reported by
 * `nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits`.
 */
export type GpuSettingInfo = {
  /** The `index` column printed by nvidia-smi. */
  id: string;
  /** The `memory.total` column: a numeric string without unit suffix (`nounits`). */
  vram: string;
  /** The `name` column. */
  name: string;
  /** Architecture label derived from `name` ('ampere' | 'ada' | 'unknown'). */
  arch?: string;
};

/**
 * Derive an architecture label from a GPU product name.
 *
 * Only names containing "nvidia" (case-insensitive) are considered. A
 * stand-alone four-digit model number 30xx maps to 'ampere' and 40xx maps to
 * 'ada'. The number must be a whole token, so unrelated digits such as
 * "GTX 1630" or "A100-SXM4-40GB" no longer match the way a plain substring
 * test for "30" / "40" did.
 *
 * @param gpuName product name as printed by nvidia-smi
 * @returns 'ampere', 'ada', or 'unknown' when the name is not recognised
 */
const getGpuArch = (gpuName: string): string => {
  if (!gpuName.toLowerCase().includes('nvidia')) return 'unknown';

  if (/\b30\d{2}\b/.test(gpuName)) return 'ampere';
  if (/\b40\d{2}\b/.test(gpuName)) return 'ada';
  return 'unknown';
};

/** Numeric VRAM of a GPU entry; non-numeric values (e.g. "[N/A]") count as 0. */
const vramAsNumber = (gpu: GpuSettingInfo): number => {
  const parsed = parseFloat(gpu.vram);
  return Number.isNaN(parsed) ? 0 : parsed;
};

/** Turn the CSV text printed by nvidia-smi into GPU entries, largest VRAM first. */
const parseGpuQueryOutput = (stdout: string): GpuSettingInfo[] => {
  const gpus: GpuSettingInfo[] = [];

  // Split on LF or CRLF so Windows output does not leave '\r' on the last field.
  for (const rawLine of stdout.split(/\r?\n/)) {
    const line = rawLine.trim();
    if (line.length === 0) continue; // empty output or trailing newline

    // Fields are separated by ", "; anything after the second separator is the name.
    const [id, vram, ...nameParts] = line.split(', ').map((field) => field.trim());
    if (id === undefined || vram === undefined || nameParts.length === 0) {
      continue; // malformed line: skip instead of throwing inside the exec callback
    }

    const name = nameParts.join(', ');
    gpus.push({ id, vram, name, arch: getGpuArch(name) });
  }

  // Highest VRAM first, so that index 0 is the most capable GPU.
  // Array.prototype.sort is stable, so ties keep nvidia-smi's order.
  return gpus.sort((a, b) => vramAsNumber(b) - vramAsNumber(a));
};

/**
 * Get GPU information from the system by running `nvidia-smi`.
 *
 * The promise never rejects: when the command fails (for example because
 * nvidia-smi is not installed) it resolves to an empty array.
 *
 * @returns GPU entries ordered by total memory, highest first
 */
export const getGpuInfo = (): Promise<GpuSettingInfo[]> =>
  new Promise<GpuSettingInfo[]>((resolve) => {
    exec(
      'nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits',
      (error, stdout) => {
        resolve(error ? [] : parseGpuQueryOutput(stdout));
      },
    );
  });

// ---------------------------------------------------------------------------
// cortex-js/src/utils/normalize-model-id.ts (additions)
// In that file `getGpuInfo` is imported from './cuda'.
// ---------------------------------------------------------------------------

/**
 * Parse the model hub engine branch.
 *
 * Branches that do not contain "tensorrt" are returned unchanged. For
 * tensorrt branches the platform ('windows' on win32, otherwise 'linux') and
 * the architecture of the first GPU returned by `getGpuInfo` are appended as
 * `-<suffix>` unless the branch already contains them.
 *
 * In huggingface.ts, `fetchJanRepoData` calls this with the part of the model
 * id after ':' (or 'default' when there is none).
 *
 * @param branch branch / tag name requested by the caller
 * @returns the branch name to fetch from the model hub
 */
export const parseModelHubEngineBranch = async (
  branch: string,
): Promise<string> => {
  if (!branch.includes('tensorrt')) return branch;

  let engineBranch = branch;

  // Every platform other than win32 is treated as linux.
  const platform = process.platform === 'win32' ? 'windows' : 'linux';
  if (!engineBranch.includes(platform)) {
    engineBranch += `-${platform}`;
  }

  const gpus = await getGpuInfo();
  const arch = gpus[0]?.arch;
  // NOTE(review): 'unknown' is truthy, so an unrecognised GPU yields a
  // "-unknown" suffix — confirm that this is the intended behaviour.
  if (arch && !engineBranch.includes(arch)) {
    engineBranch += `-${arch}`;
  }
  return engineBranch;
};