diff --git a/packages/backend/src/managers/GPUManager.spec.ts b/packages/backend/src/managers/GPUManager.spec.ts index f87470236..f3e91aec7 100644 --- a/packages/backend/src/managers/GPUManager.spec.ts +++ b/packages/backend/src/managers/GPUManager.spec.ts @@ -15,21 +15,39 @@ * * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import { expect, test, vi, beforeEach } from 'vitest'; -import type { Webview } from '@podman-desktop/api'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; +import { type ContainerProviderConnection, type Webview, process, env } from '@podman-desktop/api'; import { GPUManager } from './GPUManager'; import { graphics, type Systeminformation } from 'systeminformation'; -import { GPUVendor } from '@shared/src/models/IGPUInfo'; +import { + type ContainerDeviceInterface, + GPUVendor, + type IGPUInfo, + type NvidiaCTKVersion, +} from '@shared/src/models/IGPUInfo'; +import type { PodmanConnection } from './podmanConnection'; +import { readFile, stat } from 'node:fs/promises'; +import type { Stats } from 'node:fs'; vi.mock('../utils/inferenceUtils', () => ({ getProviderContainerConnection: vi.fn(), getImageInfo: vi.fn(), })); +vi.mock('node:fs/promises', () => ({ + stat: vi.fn(), + readFile: vi.fn(), +})); + vi.mock('@podman-desktop/api', async () => { return { env: { isWindows: false, + isLinux: false, + isMac: false, + }, + process: { + exec: vi.fn(), }, }; }); @@ -42,13 +60,21 @@ const webviewMock = { postMessage: vi.fn(), } as unknown as Webview; +const podmanConnectionMock: PodmanConnection = { + executeSSH: vi.fn(), +} as unknown as PodmanConnection; + beforeEach(() => { vi.resetAllMocks(); vi.mocked(webviewMock.postMessage).mockResolvedValue(true); + + (env.isLinux as boolean) = false; + (env.isWindows as boolean) = false; + (env.isMac as boolean) = false; }); test('post constructor should have no items', () => { - const manager = new GPUManager(webviewMock); + const manager = new GPUManager(webviewMock, podmanConnectionMock); expect(manager.getAll().length).toBe(0); }); @@ -58,7 +84,7 @@ test('no controller should return empty array', async () => { displays: [], }); - const manager = new GPUManager(webviewMock); + const manager = new GPUManager(webviewMock, podmanConnectionMock); expect(await manager.collectGPUs()).toHaveLength(0); }); @@ -74,7 +100,7 @@ test('intel controller should return intel vendor', async () => { displays: [], }); - const manager = new GPUManager(webviewMock); + const manager = new GPUManager(webviewMock, podmanConnectionMock); expect(await manager.collectGPUs()).toStrictEqual([ { vendor: GPUVendor.INTEL, @@ -96,7 +122,7 @@ test('NVIDIA controller should return intel vendor', async () => { displays: [], }); - const manager = new GPUManager(webviewMock); + const manager = new GPUManager(webviewMock, podmanConnectionMock); expect(await manager.collectGPUs()).toStrictEqual([ { vendor: GPUVendor.NVIDIA, @@ -105,3 +131,208 @@ test('NVIDIA controller should return intel vendor', async () => { }, ]); }); + +class GPUManagerTest extends GPUManager { + public override parseNvidiaCTKVersion(stdout: string): NvidiaCTKVersion { + return super.parseNvidiaCTKVersion(stdout); + } + + public override async getNvidiaContainerToolKitVersion( + connection: ContainerProviderConnection, + ): Promise { + return super.getNvidiaContainerToolKitVersion(connection); + } + + public override parseNvidiaCDI(stdout: string): ContainerDeviceInterface { + return super.parseNvidiaCDI(stdout); + } + + public override async getNvidiaCDI(connection: ContainerProviderConnection): Promise { + return super.getNvidiaCDI(connection); + } + + public override getAll(): IGPUInfo[] { + return [ + { + vendor: GPUVendor.NVIDIA, + model: 'demo-model', + vram: 555, + }, + ]; + } +} + +const NVIDIA_CTK_VERSION = `NVIDIA Container Toolkit CLI version 1.16.1\ncommit: a470818ba7d9166be282cd0039dd2fc9b0a34d73`; + +describe('parseNvidiaCTKVersion', () => { + test('valid stdout', () => { + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { version, commit } = manager.parseNvidiaCTKVersion(NVIDIA_CTK_VERSION); + expect(version).toBe('1.16.1'); + expect(commit).toBe('a470818ba7d9166be282cd0039dd2fc9b0a34d73'); + }); + + test('empty stdout', () => { + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + expect(() => { + manager.parseNvidiaCTKVersion(''); + }).toThrowError('malformed version output'); + }); +}); + +const NVIDIA_CDI = ` +--- +cdiVersion: 0.3.0 +devices: +- containerEdits: + deviceNodes: + - path: /dev/dxg + name: all +kind: nvidia.com/gpu +`; + +describe('parseNvidiaCDI', () => { + test('valid stdout', () => { + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { cdiVersion, devices, kind } = manager.parseNvidiaCDI(NVIDIA_CDI); + expect(cdiVersion).toBe('0.3.0'); + expect(devices).toStrictEqual([ + { + containerEdits: { + deviceNodes: [ + { + path: '/dev/dxg', + }, + ], + }, + name: 'all', + }, + ]); + expect(kind).toBe('nvidia.com/gpu'); + }); + + test('empty stdout', () => { + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + expect(() => { + manager.parseNvidiaCDI(''); + }).toThrowError('malformed output nvidia CDI output'); + }); +}); + +describe('getNvidiaContainerToolKitVersion', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(podmanConnectionMock.executeSSH).mockResolvedValue({ + stdout: NVIDIA_CTK_VERSION, + stderr: '', + command: '', + }); + vi.mocked(process.exec).mockResolvedValue({ + stdout: NVIDIA_CTK_VERSION, + stderr: '', + command: '', + }); + }); + + test('windows wsl connection should use executeSSH', async () => { + (env.isWindows as boolean) = true; + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { version, commit } = await manager.getNvidiaContainerToolKitVersion(WSL_CONNECTION); + expect(version).toBe('1.16.1'); + expect(commit).toBe('a470818ba7d9166be282cd0039dd2fc9b0a34d73'); + + expect(podmanConnectionMock.executeSSH).toHaveBeenCalledWith(WSL_CONNECTION, ['nvidia-ctk', '--quiet', '-v']); + expect(process.exec).not.toHaveBeenCalled(); + }); + + test('connection without vmType on non-linux system should throw an error', async () => { + (env.isLinux as boolean) = false; + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + await expect(async () => { + await manager.getNvidiaContainerToolKitVersion(NATIVE_CONNECTION); + }).rejects.toThrowError('cannot determine the environment to execute nvidia-ctk'); + }); + + test('linux with native connection should execute nvidia-ctk on host', async () => { + (env.isLinux as boolean) = true; + + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { version, commit } = await manager.getNvidiaContainerToolKitVersion(NATIVE_CONNECTION); + expect(version).toBe('1.16.1'); + expect(commit).toBe('a470818ba7d9166be282cd0039dd2fc9b0a34d73'); + + expect(process.exec).toHaveBeenCalledWith('nvidia-ctk', ['--quiet', '-v']); + expect(podmanConnectionMock.executeSSH).not.toHaveBeenCalled(); + }); +}); + +const NATIVE_CONNECTION: ContainerProviderConnection = { + status: () => 'started', + vmType: undefined, + name: 'podman', + type: 'podman', +} as unknown as ContainerProviderConnection; + +const WSL_CONNECTION: ContainerProviderConnection = { + status: () => 'started', + vmType: 'wsl', + name: 'podman', + type: 'podman', +} as unknown as ContainerProviderConnection; + +describe('getNvidiaCDI', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(podmanConnectionMock.executeSSH).mockResolvedValue({ + stdout: NVIDIA_CDI, + stderr: '', + command: '', + }); + }); + + test('windows wsl connection should use executeSSH', async () => { + (env.isWindows as boolean) = true; + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { kind } = await manager.getNvidiaCDI(WSL_CONNECTION); + expect(kind).toBe('nvidia.com/gpu'); + expect(podmanConnectionMock.executeSSH).toHaveBeenCalledWith(WSL_CONNECTION, ['cat', '/etc/cdi/nvidia.yaml']); + expect(stat).not.toHaveBeenCalled(); + }); + + test('connection without vmType on non-linux system should throw an error', async () => { + (env.isLinux as boolean) = false; + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + await expect(async () => { + await manager.getNvidiaCDI(NATIVE_CONNECTION); + }).rejects.toThrowError('cannot determine the environment to read nvidia CDI file'); + }); + + test('linux with native connection should read config on host', async () => { + (env.isLinux as boolean) = true; + vi.mocked(stat).mockResolvedValue({ + isFile: () => true, + } as Stats); + vi.mocked(readFile).mockResolvedValue(NVIDIA_CDI); + + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + const { kind } = await manager.getNvidiaCDI(NATIVE_CONNECTION); + expect(kind).toBe('nvidia.com/gpu'); + // on native linux we should not ssh in the machine + expect(podmanConnectionMock.executeSSH).not.toHaveBeenCalled(); + expect(stat).toHaveBeenCalledWith('/etc/cdi/nvidia.yaml'); + expect(readFile).toHaveBeenCalledWith('/etc/cdi/nvidia.yaml', { encoding: 'utf8' }); + }); + + test('linux with native connection should throw an error if file does not exists', async () => { + (env.isLinux as boolean) = true; + vi.mocked(stat).mockRejectedValue(new Error('file do not exists')); + + const manager = new GPUManagerTest(webviewMock, podmanConnectionMock); + + await expect(async () => { + await manager.getNvidiaCDI(NATIVE_CONNECTION); + }).rejects.toThrowError('file do not exists'); + + expect(readFile).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/backend/src/managers/GPUManager.ts b/packages/backend/src/managers/GPUManager.ts index 1ddd324c2..5351d894c 100644 --- a/packages/backend/src/managers/GPUManager.ts +++ b/packages/backend/src/managers/GPUManager.ts @@ -15,11 +15,26 @@ * * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import { type Disposable, type Webview } from '@podman-desktop/api'; -import { GPUVendor, type IGPUInfo } from '@shared/src/models/IGPUInfo'; +import { + type ContainerProviderConnection, + type Disposable, + type Webview, + process, + type RunResult, + env, +} from '@podman-desktop/api'; +import { + type ContainerDeviceInterface, + GPUVendor, + type IGPUInfo, + type NvidiaCTKVersion, +} from '@shared/src/models/IGPUInfo'; import { Publisher } from '../utils/Publisher'; import { Messages } from '@shared/Messages'; import { graphics } from 'systeminformation'; +import type { PodmanConnection } from './podmanConnection'; +import { load } from 'js-yaml'; +import { readFile, stat } from 'node:fs/promises'; /** * @experimental @@ -27,7 +42,10 @@ import { graphics } from 'systeminformation'; export class GPUManager extends Publisher implements Disposable { #gpus: IGPUInfo[]; - constructor(webview: Webview) { + constructor( + webview: Webview, + private podman: PodmanConnection, + ) { super(webview, Messages.MSG_GPUS_UPDATE, () => this.getAll()); // init properties this.#gpus = []; @@ -41,11 +59,12 @@ export class GPUManager extends Publisher implements Disposable { async collectGPUs(): Promise { const { controllers } = await graphics(); - return controllers.map(controller => ({ + this.#gpus = controllers.map(controller => ({ vendor: this.getVendor(controller.vendor), model: controller.model, vram: controller.vram ?? undefined, })); + return this.getAll(); } protected getVendor(raw: string): GPUVendor { @@ -60,4 +79,100 @@ export class GPUManager extends Publisher implements Disposable { return GPUVendor.UNKNOWN; } } + + protected parseNvidiaCTKVersion(stdout: string): NvidiaCTKVersion { + const lines = stdout.split('\n'); + if (lines.length !== 2) throw new Error('malformed version output'); + return { + version: lines[0].substring('NVIDIA Container Toolkit CLI version'.length).trim(), + commit: lines[1].substring('commit:'.length).trim(), + }; + } + + protected async getNvidiaContainerToolKitVersion(connection: ContainerProviderConnection): Promise { + let result: RunResult; + + if (connection.vmType) { + // if vmType is defined we are working with virtual machine so we need to SSH in it + result = await this.podman.executeSSH(connection, ['nvidia-ctk', '--quiet', '-v']); + } else if (env.isLinux) { + // if vmType is undefined on linux system we are working with podman native + result = await process.exec('nvidia-ctk', ['--quiet', '-v']); + } else { + throw new Error('cannot determine the environment to execute nvidia-ctk'); + } + if (result.stderr.length > 0) throw new Error(result.stderr); + return this.parseNvidiaCTKVersion(result.stdout); + } + + protected parseNvidiaCDI(stdout: string): ContainerDeviceInterface { + const containerDeviceInterface: unknown = load(stdout); + if (!containerDeviceInterface || typeof containerDeviceInterface !== 'object') + throw new Error('malformed output nvidia CDI output'); + if (!('cdiVersion' in containerDeviceInterface)) throw new Error('missing cdiVersion in nvidia CDI'); + if (containerDeviceInterface.cdiVersion !== '0.3.0') + throw new Error('invalid cdiVersion: expected 0.3.0 received containerDeviceInterface.cdiVersion'); + + if (!('kind' in containerDeviceInterface)) throw new Error('missing kind in nvidia CDI'); + if (typeof containerDeviceInterface.kind !== 'string') throw new Error('malformed kind in nvidia CDI'); + + if (!('devices' in containerDeviceInterface)) throw new Error('missing devices in nvidia CDI'); + if (!Array.isArray(containerDeviceInterface.devices)) throw new Error('devices is malformed in nvidia CDI'); + + return { + cdiVersion: containerDeviceInterface.cdiVersion, + kind: containerDeviceInterface.kind, + devices: containerDeviceInterface.devices, + }; + } + + /** + * This method will parse the `/etc/cdi/nvidia.yaml` in the available ContainerProviderConnection + * @protected + */ + protected async getNvidiaCDI(connection: ContainerProviderConnection): Promise { + if (connection.vmType) { + // if vmType is defined we are working with virtual machine so we need to SSH in it + const { stdout, stderr } = await this.podman.executeSSH(connection, ['cat', '/etc/cdi/nvidia.yaml']); + if (stderr.length > 0) throw new Error(stderr); + return this.parseNvidiaCDI(stdout); + } + + if (!env.isLinux) { + throw new Error('cannot determine the environment to read nvidia CDI file'); + } + + // if vmType is undefined on linux system we are working with podman native + const info = await stat('/etc/cdi/nvidia.yaml'); + if (!info.isFile()) throw new Error('invalid /etc/cdi/nvidia.yaml file'); + + const content = await readFile('/etc/cdi/nvidia.yaml', { encoding: 'utf8' }); + return this.parseNvidiaCDI(content); + } + + /** + * see https://github.com/cncf-tags/container-device-interface + * @param connection + */ + public async getGPUContainerDeviceInterface( + connection: ContainerProviderConnection, + ): Promise { + const gpus = this.getAll(); + // ensure at least one GPU is available + if (gpus.length === 0) { + throw new Error('no gpu has been detected'); + } + + // ensure all GPU(s) are NVIDIA vendor + if (gpus.some(gpu => gpu.vendor !== GPUVendor.NVIDIA)) { + throw new Error('cannot get container device interface for non-nvidia GPU(s)'); + } + + // check nvidia-ctk version + const { version } = await this.getNvidiaContainerToolKitVersion(connection); + console.log('nvidia-ctk version', version); + + // get the nvidia Container device interface + return this.getNvidiaCDI(connection); + } } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index d1e50a4e6..c42c36bb8 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -241,7 +241,7 @@ export class Studio { /** * GPUManager is a class responsible for detecting and storing the GPU specs */ - this.#gpuManager = new GPUManager(this.#panel.webview); + this.#gpuManager = new GPUManager(this.#panel.webview, this.#podmanConnection); this.#extensionContext.subscriptions.push(this.#gpuManager); /** diff --git a/packages/shared/src/models/IGPUInfo.ts b/packages/shared/src/models/IGPUInfo.ts index 7bba4b2db..913462d7c 100644 --- a/packages/shared/src/models/IGPUInfo.ts +++ b/packages/shared/src/models/IGPUInfo.ts @@ -28,3 +28,19 @@ export enum GPUVendor { INTEL = 'Intel Corporation', UNKNOWN = 'unknown', } + +export interface NvidiaCTKVersion { + version: string; + commit: string; +} + +/** + * ref https://github.com/cncf-tags/container-device-interface + */ +export interface ContainerDeviceInterface { + cdiVersion: string; + kind: string; + devices: { + name: string; + }[]; +}