Skip to content

Commit

Permalink
feat: improve podman cli execution
Browse files Browse the repository at this point in the history
Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 committed Oct 1, 2024
1 parent 9a1421f commit 1f0b6f5
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 168 deletions.
22 changes: 4 additions & 18 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import os from 'node:os';
import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import { env, process as coreProcess } from '@podman-desktop/api';
import { env } from '@podman-desktop/api';
import type { RunResult, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
Expand All @@ -33,7 +33,6 @@ import type { GGUFParseOutput } from '@huggingface/gguf';
import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import { getPodmanMachineName } from '../utils/podman';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
import { Uploader } from '../utils/uploader';

Expand All @@ -47,7 +46,6 @@ const mocks = vi.hoisted(() => {
getTargetMock: vi.fn(),
getDownloaderCompleter: vi.fn(),
isCompletionEventMock: vi.fn(),
getPodmanCliMock: vi.fn(),
};
});

Expand All @@ -59,11 +57,6 @@ vi.mock('@huggingface/gguf', () => ({
gguf: vi.fn(),
}));

vi.mock('../utils/podman', () => ({
getPodmanCli: mocks.getPodmanCliMock,
getPodmanMachineName: vi.fn(),
}));

vi.mock('@podman-desktop/api', () => {
return {
Disposable: {
Expand All @@ -72,9 +65,6 @@ vi.mock('@podman-desktop/api', () => {
env: {
isWindows: false,
},
process: {
exec: vi.fn(),
},
fs: {
createFileSystemWatcher: (): unknown => ({
onDidCreate: vi.fn(),
Expand Down Expand Up @@ -102,6 +92,7 @@ vi.mock('../utils/downloader', () => ({

const podmanConnectionMock = {
getContainerProviderConnections: vi.fn(),
executeSSH: vi.fn(),
} as unknown as PodmanConnection;

const cancellationTokenRegistryMock = {
Expand Down Expand Up @@ -598,8 +589,7 @@ describe('deleting models', () => {
});

test('deleting on windows should check for all connections', async () => {
vi.mocked(coreProcess.exec).mockResolvedValue({} as RunResult);
mocks.getPodmanCliMock.mockReturnValue('dummyCli');
vi.mocked(podmanConnectionMock.executeSSH).mockResolvedValue({} as RunResult);
vi.mocked(env).isWindows = true;
const connections: ContainerProviderConnection[] = [
{
Expand All @@ -622,7 +612,6 @@ describe('deleting models', () => {
},
];
vi.mocked(podmanConnectionMock.getContainerProviderConnections).mockReturnValue(connections);
vi.mocked(getPodmanMachineName).mockReturnValue('machine-2');

const rmSpy = vi.spyOn(fs.promises, 'rm');
rmSpy.mockResolvedValue(undefined);
Expand Down Expand Up @@ -659,10 +648,7 @@ describe('deleting models', () => {

expect(podmanConnectionMock.getContainerProviderConnections).toHaveBeenCalledOnce();

expect(coreProcess.exec).toHaveBeenCalledWith('dummyCli', [
'machine',
'ssh',
'machine-2',
expect(podmanConnectionMock.executeSSH).toHaveBeenCalledWith(connections[1], [
'rm',
'-f',
'/home/user/ai-lab/models/dummyFile',
Expand Down
32 changes: 23 additions & 9 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ import type { Task } from '@shared/src/models/ITask';
import type { BaseEvent } from '../models/baseEvent';
import { isCompletionEvent, isProgressEvent } from '../models/baseEvent';
import { Uploader } from '../utils/uploader';
import { deleteRemoteModel, getLocalModelFile, isModelUploaded } from '../utils/modelsUtils';
import { getPodmanMachineName } from '../utils/podman';
import { getLocalModelFile, getRemoteModelFile } from '../utils/modelsUtils';
import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
import { getHash, hasValidSha } from '../utils/sha';
import type { GGUFParseOutput } from '@huggingface/gguf';
Expand Down Expand Up @@ -231,17 +230,32 @@ export class ModelsManager implements Disposable {
for (const connection of connections) {
// ignore non-wsl machines
if (connection.vmType !== VMType.WSL) continue;
// Get the corresponding machine name
const machineName = getPodmanMachineName(connection);

// check if model already loaded on the podman machine
const existsRemote = await isModelUploaded(machineName, modelInfo);
if (!existsRemote) return;
// check if remote model is
try {
await this.podmanConnection.executeSSH(connection, ['stat', getRemoteModelFile(modelInfo)]);
} catch (err: unknown) {
console.warn(err);
continue;
}

await deleteRemoteModel(machineName, modelInfo);
await this.deleteRemoteModelByConnection(connection, modelInfo);
}
}

/**
* Delete a model given a {@link ContainerProviderConnection}
* @param connection
* @param modelInfo
* @protected
*/
protected async deleteRemoteModelByConnection(
connection: ContainerProviderConnection,
modelInfo: ModelInfo,
): Promise<void> {
await this.podmanConnection.executeSSH(connection, ['rm', '-f', getRemoteModelFile(modelInfo)]);
}

/**
* This method will resolve when the provided model will be downloaded.
*
Expand Down Expand Up @@ -439,7 +453,7 @@ export class ModelsManager implements Disposable {
connection: connection.name,
});

const uploader = new Uploader(connection, model);
const uploader = new Uploader(this.podmanConnection, connection, model);
uploader.onEvent(event => this.onDownloadUploadEvent(event, 'upload'), this);

// perform download
Expand Down
61 changes: 1 addition & 60 deletions packages/backend/src/utils/modelsUtils.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,8 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { beforeEach, describe, expect, test, vi } from 'vitest';
import { process as apiProcess } from '@podman-desktop/api';
import {
deleteRemoteModel,
getLocalModelFile,
getRemoteModelFile,
isModelUploaded,
MACHINE_BASE_FOLDER,
} from './modelsUtils';
import { getLocalModelFile, getRemoteModelFile, MACHINE_BASE_FOLDER } from './modelsUtils';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getPodmanCli } from './podman';

vi.mock('@podman-desktop/api', () => {
return {
Expand All @@ -35,14 +27,8 @@ vi.mock('@podman-desktop/api', () => {
};
});

vi.mock('./podman', () => ({
getPodmanCli: vi.fn(),
}));

beforeEach(() => {
vi.resetAllMocks();

vi.mocked(getPodmanCli).mockReturnValue('dummyPodmanCli');
});

describe('getLocalModelFile', () => {
Expand Down Expand Up @@ -94,48 +80,3 @@ describe('getRemoteModelFile', () => {
expect(path).toBe(`${MACHINE_BASE_FOLDER}dummy.guff`);
});
});

describe('isModelUploaded', () => {
test('execute stat on targeted machine', async () => {
expect(
await isModelUploaded('dummyMachine', {
id: 'dummyModelId',
file: {
path: 'dummyPath',
file: 'dummy.guff',
},
} as unknown as ModelInfo),
).toBeTruthy();

expect(getPodmanCli).toHaveBeenCalled();
expect(apiProcess.exec).toHaveBeenCalledWith('dummyPodmanCli', [
'machine',
'ssh',
'dummyMachine',
'stat',
expect.anything(),
]);
});
});

describe('deleteRemoteModel', () => {
test('execute stat on targeted machine', async () => {
await deleteRemoteModel('dummyMachine', {
id: 'dummyModelId',
file: {
path: 'dummyPath',
file: 'dummy.guff',
},
} as unknown as ModelInfo);

expect(getPodmanCli).toHaveBeenCalled();
expect(apiProcess.exec).toHaveBeenCalledWith('dummyPodmanCli', [
'machine',
'ssh',
'dummyMachine',
'rm',
'-f',
expect.anything(),
]);
});
});
32 changes: 0 additions & 32 deletions packages/backend/src/utils/modelsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
***********************************************************************/
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { join, posix } from 'node:path';
import { getPodmanCli } from './podman';
import { process } from '@podman-desktop/api';

export const MACHINE_BASE_FOLDER = '/home/user/ai-lab/models/';

Expand All @@ -42,36 +40,6 @@ export function getRemoteModelFile(modelInfo: ModelInfo): string {
return posix.join(MACHINE_BASE_FOLDER, modelInfo.file.file);
}

/**
* utility method to determine if a model is already uploaded to the podman machine
* @param machine
* @param modelInfo
*/
export async function isModelUploaded(machine: string, modelInfo: ModelInfo): Promise<boolean> {
try {
const remotePath = getRemoteModelFile(modelInfo);
await process.exec(getPodmanCli(), ['machine', 'ssh', machine, 'stat', remotePath]);
return true;
} catch (err: unknown) {
console.error('Something went wrong while trying to stat remote model path', err);
return false;
}
}

/**
* Given a machine and a modelInfo, delete the corresponding file on the podman machine
* @param machine the machine to target
* @param modelInfo the model info
*/
export async function deleteRemoteModel(machine: string, modelInfo: ModelInfo): Promise<void> {
try {
const remotePath = getRemoteModelFile(modelInfo);
await process.exec(getPodmanCli(), ['machine', 'ssh', machine, 'rm', '-f', remotePath]);
} catch (err: unknown) {
console.error('Something went wrong while trying to stat remote model path', err);
}
}

export function getModelPropertiesForEnvironment(modelInfo: ModelInfo): string[] {
const envs: string[] = [];
if (modelInfo.properties) {
Expand Down
8 changes: 7 additions & 1 deletion packages/backend/src/utils/podman.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ export type MachineJSON = {
VMType?: string;
};

/**
* We should be using the {@link @podman-desktop/api#extensions.getExtensions} function to get podman
* exec method
* @deprecated
*/
export function getPodmanCli(): string {
// If we have a custom binary path regardless if we are running Windows or not
const customBinaryPath = getCustomBinaryPath();
Expand All @@ -52,8 +57,9 @@ export function getCustomBinaryPath(): string | undefined {
}

/**
* In the ${link ContainerProviderConnection.name} property the name is not usage, and we need to transform it
* In the {@link ContainerProviderConnection.name} property the name is not usage, and we need to transform it
* @param connection
* @deprecated
*/
export function getPodmanMachineName(connection: ContainerProviderConnection): string {
const runningConnectionName = connection.name;
Expand Down
10 changes: 6 additions & 4 deletions packages/backend/src/utils/uploader.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ import { Uploader } from './uploader';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { ContainerProviderConnection } from '@podman-desktop/api';
import { VMType } from '@shared/src/models/IPodman';
import type { PodmanConnection } from '../managers/podmanConnection';

vi.mock('@podman-desktop/api', async () => {
return {
env: {
isWindows: false,
},
process: {
exec: vi.fn(),
},
EventEmitter: vi.fn().mockImplementation(() => {
return {
fire: vi.fn(),
Expand All @@ -41,6 +39,10 @@ vi.mock('@podman-desktop/api', async () => {
};
});

const podmanConnectionMock: PodmanConnection = {
executeSSH: vi.fn(),
} as unknown as PodmanConnection;

const connectionMock: ContainerProviderConnection = {
name: 'machine2',
type: 'podman',
Expand All @@ -51,7 +53,7 @@ const connectionMock: ContainerProviderConnection = {
},
};

const uploader = new Uploader(connectionMock, {
const uploader = new Uploader(podmanConnectionMock, connectionMock, {
id: 'dummyModelId',
file: {
file: 'dummyFile.guff',
Expand Down
4 changes: 3 additions & 1 deletion packages/backend/src/utils/uploader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getLocalModelFile } from './modelsUtils';
import type { IWorker } from '../workers/IWorker';
import type { UploaderOptions } from '../workers/uploader/UploaderOptions';
import type { PodmanConnection } from '../managers/podmanConnection';

export class Uploader {
readonly #_onEvent = new EventEmitter<BaseEvent>();
readonly onEvent: Event<BaseEvent> = this.#_onEvent.event;
readonly #workers: IWorker<UploaderOptions, string>[] = [];

constructor(
podman: PodmanConnection,
private connection: ContainerProviderConnection,
private modelInfo: ModelInfo,
private abortSignal?: AbortSignal,
) {
this.#workers = [new WSLUploader()];
this.#workers = [new WSLUploader(podman)];
}

/**
Expand Down
Loading

0 comments on commit 1f0b6f5

Please sign in to comment.