Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(win): better gpu detection #1141

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,21 @@
"watch": "vite --mode development build -w"
},
"dependencies": {
"fast-xml-parser": "^4.4.0",
"isomorphic-git": "^1.25.10",
"mustache": "^4.2.0",
"openai": "^4.47.2",
"postman-code-generators": "^1.10.1",
"postman-collection": "^4.4.0",
"semver": "^7.6.2",
"winreg": "^1.2.5",
"xml-js": "^1.6.11"
},
"devDependencies": {
"@podman-desktop/api": "0.0.202404101645-5d46ba5",
"@types/js-yaml": "^4.0.9",
"@types/node": "^20",
"@types/postman-collection": "^3.5.10",
"@types/winreg": "^1.2.36",
"vitest": "^1.6.0",
"@types/mustache": "^4.2.5"
}
Expand Down
79 changes: 3 additions & 76 deletions packages/backend/src/managers/GPUManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { expect, test, vi, beforeEach } from 'vitest';
import { containerEngine, env } from '@podman-desktop/api';
import type { ContainerInspectInfo, ContainerProviderConnection, ImageInfo, Webview } from '@podman-desktop/api';
import { env } from '@podman-desktop/api';
import type { Webview } from '@podman-desktop/api';
import { GPUManager } from './GPUManager';
import { getImageInfo, getProviderContainerConnection } from '../utils/inferenceUtils';
import { XMLParser } from 'fast-xml-parser';

vi.mock('../utils/inferenceUtils', () => ({
getProviderContainerConnection: vi.fn(),
Expand All @@ -29,12 +27,6 @@ vi.mock('../utils/inferenceUtils', () => ({

vi.mock('@podman-desktop/api', async () => {
return {
containerEngine: {
createContainer: vi.fn(),
logsContainer: vi.fn(),
deleteContainer: vi.fn(),
inspectContainer: vi.fn(),
},
env: {
isWindows: false,
},
Expand All @@ -52,47 +44,6 @@ const webviewMock = {
beforeEach(() => {
vi.resetAllMocks();
vi.mocked(webviewMock.postMessage).mockResolvedValue(true);

vi.mocked(getProviderContainerConnection).mockReturnValue({
providerId: 'dummyProviderId',
connection: {} as unknown as ContainerProviderConnection,
});
vi.mocked(getImageInfo).mockResolvedValue({
engineId: 'dummyEngineId',
Id: 'dummyImageId',
} as unknown as ImageInfo);

vi.mocked(containerEngine.createContainer).mockResolvedValue({
id: 'dummyContainerId',
});

vi.mocked(containerEngine.logsContainer).mockImplementation(async (_engineId, _containerId, callback) => {
callback('', '</nvidia_smi_log>');
});

vi.mocked(XMLParser).mockReturnValue({
parse: vi.fn().mockReturnValue({
nvidia_smi_log: {
attached_gpus: 1,
cuda_version: 2,
driver_version: 3,
timestamp: 4,
gpu: {
uuid: 'dummyUUID',
product_name: 'dummyProductName',
},
},
}),
} as unknown as XMLParser);

vi.mocked(containerEngine.inspectContainer).mockImplementation(async (_engineId, _id) => {
return {
State: {
Running: false,
ExitCode: 0,
},
} as unknown as ContainerInspectInfo;
});
});

test('post constructor should have no items', () => {
Expand All @@ -106,29 +57,5 @@ test('non-windows host should throw error', async () => {
const manager = new GPUManager(webviewMock);
await expect(() => {
return manager.collectGPUs();
}).rejects.toThrowError('Cannot collect GPUs information on this machine.');
});

test('windows host should start then delete container with proper configuration', async () => {
vi.mocked(env).isWindows = true;

const manager = new GPUManager(webviewMock);
const gpus = await manager.collectGPUs({
providerId: 'dummyProviderId',
});

expect(gpus.length).toBe(1);
expect(gpus[0].uuid).toBe('dummyUUID');
expect(gpus[0].product_name).toBe('dummyProductName');

expect(getProviderContainerConnection).toHaveBeenCalledWith('dummyProviderId');

expect(containerEngine.createContainer).toHaveBeenCalledWith('dummyEngineId', {
Image: 'dummyImageId',
Cmd: expect.anything(),
Detach: false,
Entrypoint: '/usr/bin/sh',
HostConfig: expect.anything(),
});
expect(containerEngine.deleteContainer).toHaveBeenCalledWith('dummyEngineId', 'dummyContainerId');
}).rejects.toThrowError();
});
155 changes: 18 additions & 137 deletions packages/backend/src/managers/GPUManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,159 +15,40 @@
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import {
containerEngine,
type Disposable,
type Webview,
type ImageInfo,
type PullEvent,
type ContainerCreateOptions,
env,
} from '@podman-desktop/api';
import { getImageInfo, getProviderContainerConnection } from '../utils/inferenceUtils';
import { XMLParser } from 'fast-xml-parser';
import { type Disposable, type Webview } from '@podman-desktop/api';
import type { IGPUInfo } from '@shared/src/models/IGPUInfo';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';

export const CUDA_UBI8_IMAGE = 'nvcr.io/nvidia/cuda:12.3.2-devel-ubi8';
import type { IWorker } from '../workers/IWorker';
import { WinGPUDetector } from '../workers/gpu/WinGPUDetector';
import { platform } from 'node:os';

/**
* @experimental
*/
export class GPUManager extends Publisher<IGPUInfo[]> implements Disposable {
// Map uuid -> info
#gpus: Map<string, IGPUInfo>;
#gpus: IGPUInfo[];

#workers: IWorker<void, IGPUInfo[]>[];

constructor(webview: Webview) {
super(webview, Messages.MSG_GPUS_UPDATE, () => this.getAll());
this.#gpus = new Map();
}
dispose(): void {
this.#gpus.clear();
}

getAll(): IGPUInfo[] {
return Array.from(this.#gpus.values());
// init properties
this.#gpus = [];
this.#workers = [new WinGPUDetector()];
}

async collectGPUs(options?: { providerId: string }): Promise<IGPUInfo[]> {
if (!env.isWindows) {
throw new Error('Cannot collect GPUs information on this machine.');
}

const provider = getProviderContainerConnection(options?.providerId);
const imageInfo: ImageInfo = await getImageInfo(provider.connection, CUDA_UBI8_IMAGE, (_event: PullEvent) => {});

const result = await containerEngine.createContainer(
imageInfo.engineId,
this.getWindowsContainerCreateOptions(imageInfo),
);

const exitCode = await this.waitForExit(imageInfo.engineId, result.id);
if (exitCode !== 0) throw new Error(`nvidia CUDA Container exited with code ${exitCode}.`);

try {
const logs = await this.getLogs(imageInfo.engineId, result.id);
const parsed: {
nvidia_smi_log: {
attached_gpus: number;
cuda_version: number;
driver_version: number;
timestamp: string;
gpu: IGPUInfo;
};
} = new XMLParser().parse(logs);
dispose(): void {}

if (parsed.nvidia_smi_log.attached_gpus > 1) throw new Error('machine with more than one GPU are not supported.');

this.#gpus.set(parsed.nvidia_smi_log.gpu.uuid, parsed.nvidia_smi_log.gpu);
this.notify();
return this.getAll();
} finally {
await containerEngine.deleteContainer(imageInfo.engineId, result.id);
}
}

private getWindowsContainerCreateOptions(imageInfo: ImageInfo): ContainerCreateOptions {
return {
Image: imageInfo.Id,
Detach: false,
HostConfig: {
AutoRemove: false,
Mounts: [
{
Target: '/usr/lib/wsl',
Source: '/usr/lib/wsl',
Type: 'bind',
},
],
DeviceRequests: [
{
Capabilities: [['gpu']],
Count: -1, // -1: all
},
],
Devices: [
{
PathOnHost: '/dev/dxg',
PathInContainer: '/dev/dxg',
CgroupPermissions: 'r',
},
],
},
Entrypoint: '/usr/bin/sh',
Cmd: [
'-c',
'/usr/bin/ln -s /usr/lib/wsl/lib/* /usr/lib64/ && PATH="${PATH}:/usr/lib/wsl/lib/" && nvidia-smi -x -q',
],
};
}

private waitForExit(engineId: string, containerId: string): Promise<number> {
return new Promise<number>((resolve, reject) => {
let retry = 0;
const interval = setInterval(() => {
if (retry === 3) {
reject(new Error('timeout: container never exited.'));
return;
}

retry++;

containerEngine
.inspectContainer(engineId, containerId)
.then(inspectInfo => {
if (inspectInfo.State.Running) return;

clearInterval(interval);
resolve(inspectInfo.State.ExitCode);
})
.catch((err: unknown) => {
console.error('Something went wrong while trying to inspect container', err);
clearInterval(interval);
reject(new Error(`Failed to inspect container ${containerId}.`));
});
}, 2000);
});
getAll(): IGPUInfo[] {
return this.#gpus;
}

private getLogs(engineId: string, containerId: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
const interval = setTimeout(() => {
reject(new Error('timeout'));
}, 10000);
async collectGPUs(): Promise<IGPUInfo[]> {
const worker = this.#workers.find(worker => worker.enabled());
if (worker === undefined) throw new Error(`no worker enable to collect GPU on platform ${platform}`);

let logs = '';
containerEngine
.logsContainer(engineId, containerId, (name, data) => {
logs += data;
if (data.includes('</nvidia_smi_log>')) {
clearTimeout(interval);
resolve(logs);
}
})
.catch(reject);
});
this.#gpus = await worker.perform();
return this.getAll();
}
}
Loading