Skip to content

Commit

Permalink
feat: prevent user from deleting model when inUse (#1321)
Browse files Browse the repository at this point in the history
* feat: prevent user from deleting model when inUse

Signed-off-by: axel7083 <[email protected]>

* fix: remove unused span

Signed-off-by: axel7083 <[email protected]>

* Apply suggestions from code review

Signed-off-by: axel7083 <[email protected]>

* Update packages/frontend/src/lib/table/model/ModelColumnIcon.svelte

Signed-off-by: axel7083 <[email protected]>

* Update packages/frontend/src/lib/table/model/ModelColumnActions.svelte

Signed-off-by: axel7083 <[email protected]>

* Update packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Jul 3, 2024
1 parent 824f329 commit 1024cee
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 23 deletions.
14 changes: 11 additions & 3 deletions packages/frontend/src/lib/icons/ModelWhite.svelte
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
<script lang="ts">
export let size = '40';
export let solid: boolean = false;
const fg = solid ? 'white' : '#888';
const fg = solid ? 'white' : 'currentColor';
</script>

<div role="img" class="rounded py-[6px] pl-[7px] pr-[5px]" class:bg-green-400="{solid}">
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<div role="img">
<svg
width="{size}"
height="{size}"
style="{$$props.style}"
class="{$$props.class}"
viewBox="0 0 16 16"
xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_47_118)">
<g clip-path="url(#clip1_47_118)">
<path
Expand Down
57 changes: 57 additions & 0 deletions packages/frontend/src/lib/table/model/ModelColumnAction.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/svelte';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import ModelColumnActions from '/@/lib/table/model/ModelColumnActions.svelte';
import { router } from 'tinro';
import { type InferenceServer, InferenceType } from '@shared/src/models/IInference';

const mocks = vi.hoisted(() => ({
requestRemoveLocalModel: vi.fn(),
openFile: vi.fn(),
downloadModel: vi.fn(),
getInferenceServersMock: vi.fn<void[], InferenceServer[]>(),
}));

vi.mock('/@/utils/client', () => ({
Expand All @@ -37,9 +39,20 @@ vi.mock('/@/utils/client', () => ({
},
}));

vi.mock('../../../stores/inferenceServers', () => ({
inferenceServers: {
subscribe: (f: (msg: InferenceServer[]) => void) => {
f(mocks.getInferenceServersMock());
return () => {};
},
},
}));

beforeEach(() => {
vi.resetAllMocks();

mocks.getInferenceServersMock.mockReturnValue([]);

mocks.downloadModel.mockResolvedValue(undefined);
mocks.openFile.mockResolvedValue(undefined);
mocks.requestRemoveLocalModel.mockResolvedValue(undefined);
Expand Down Expand Up @@ -160,3 +173,47 @@ test('Expect router to be called when rocket icon clicked', async () => {
expect(replaceMock).toHaveBeenCalledWith({ 'model-id': 'my-model' });
});
});

test('Expect delete button to be disabled when model in use', async () => {
const object: ModelInfo = {
id: 'my-model',
description: '',
hw: '',
license: '',
name: '',
registry: '',
url: '',
file: {
file: 'file',
creation: new Date(),
size: 1000,
path: 'path',
},
memory: 1000,
};

mocks.getInferenceServersMock.mockReturnValue([
{
models: [object],
type: InferenceType.LLAMA_CPP,
status: 'running',
container: {
containerId: '',
engineId: '',
},
connection: {
port: 0,
},
health: undefined,
},
]);
render(ModelColumnActions, { object });

const deleteBtn = screen.getByTitle('Delete Model');
expect(deleteBtn).toBeDefined();

await vi.waitFor(() => {
// disable class
expect(deleteBtn.classList).toContain('text-[var(--pd-action-button-disabled-text)]');
});
});
17 changes: 16 additions & 1 deletion packages/frontend/src/lib/table/model/ModelColumnActions.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ import { faFolderOpen } from '@fortawesome/free-solid-svg-icons';
import ListItemButtonIcon from '../../button/ListItemButtonIcon.svelte';
import { studioClient } from '/@/utils/client';
import { router } from 'tinro';
import { onMount } from 'svelte';
import { inferenceServers } from '/@/stores/inferenceServers';
export let object: ModelInfo;
let inUse: boolean = false;
$: inUse;
function deleteModel() {
studioClient.requestRemoveLocalModel(object.id).catch(err => {
console.error(`Something went wrong while trying to delete model ${String(err)}.`);
Expand All @@ -31,6 +36,12 @@ function createModelService() {
router.goto('/service/create');
router.location.query.replace({ 'model-id': object.id });
}
onMount(() => {
return inferenceServers.subscribe(servers => {
inUse = servers.some(server => server.models.some(model => model.id === object.id));
});
});
</script>

{#if object.file !== undefined}
Expand All @@ -44,7 +55,11 @@ function createModelService() {
onClick="{() => openModelFolder()}"
title="Open Model Folder"
enabled="{!object.state}" />
<ListItemButtonIcon icon="{faTrash}" onClick="{deleteModel}" title="Delete Model" enabled="{!object.state}" />
<ListItemButtonIcon
icon="{faTrash}"
onClick="{deleteModel}"
title="Delete Model"
enabled="{!inUse && !object.state}" />
{:else}
<ListItemButtonIcon icon="{faDownload}" onClick="{downloadModel}" title="Download Model" enabled="{!object.state}" />
{/if}
94 changes: 77 additions & 17 deletions packages/frontend/src/lib/table/model/ModelColumnIcon.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,55 @@
***********************************************************************/

import '@testing-library/jest-dom/vitest';
import { test, expect } from 'vitest';
import { expect, test, vi, beforeEach } from 'vitest';
import { render, screen } from '@testing-library/svelte';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import ModelColumnIcon from './ModelColumnIcon.svelte';
import { type InferenceServer, InferenceType } from '@shared/src/models/IInference';

test('Expect green background when model has a file', async () => {
const d = new Date();
d.setDate(d.getDate() - 2);
const mocks = vi.hoisted(() => {
return {
getInferenceServersMock: vi.fn<void[], InferenceServer[]>(),
};
});

vi.mock('../../../stores/inferenceServers', () => ({
inferenceServers: {
subscribe: (f: (msg: InferenceServer[]) => void) => {
f(mocks.getInferenceServersMock());
return () => {};
},
},
}));

beforeEach(() => {
vi.resetAllMocks();
});

test('Expect remote model to have NONE title', async () => {
const object: ModelInfo = {
id: 'my-model',
id: 'model-downloaded-id',
description: '',
hw: '',
license: '',
name: '',
registry: '',
url: '',
memory: 1000,
};

mocks.getInferenceServersMock.mockReturnValue([]);

render(ModelColumnIcon, { object });

const role = screen.getByRole('status');
expect(role).toBeDefined();
expect(role.title).toBe('NONE');
});

test('Expect downloaded model to have DOWNLOADED title', async () => {
const object: ModelInfo = {
id: 'model-downloaded-id',
description: '',
hw: '',
license: '',
Expand All @@ -36,36 +74,58 @@ test('Expect green background when model has a file', async () => {
url: '',
file: {
file: 'file',
creation: d,
creation: undefined,
size: 1000,
path: 'path',
},
memory: 1000,
};

mocks.getInferenceServersMock.mockReturnValue([]);

render(ModelColumnIcon, { object });

const logo = screen.getByRole('img');
expect(logo).toBeInTheDocument();
expect(logo).toHaveClass(/^bg-green-/);
const role = screen.getByRole('status');
expect(role).toBeDefined();
expect(role.title).toBe('DOWNLOADED');
});

test('Expect non green background when model has no file', async () => {
const d = new Date();
d.setDate(d.getDate() - 2);

test('Expect in used model to have USED title', async () => {
const object: ModelInfo = {
id: 'my-model',
id: 'model-in-used-id',
description: '',
hw: '',
license: '',
name: '',
registry: '',
url: '',
file: {
file: 'file',
creation: undefined,
size: 1000,
path: 'path',
},
memory: 1000,
};

mocks.getInferenceServersMock.mockReturnValue([
{
models: [object],
type: InferenceType.LLAMA_CPP,
status: 'running',
container: {
containerId: '',
engineId: '',
},
connection: {
port: 0,
},
health: undefined,
},
]);
render(ModelColumnIcon, { object });

const logo = screen.getByRole('img');
expect(logo).toBeInTheDocument();
expect(logo).not.toHaveClass(/^bg-green-/);
const role = screen.getByRole('status');
expect(role).toBeDefined();
expect(role.title).toBe('USED');
});
18 changes: 17 additions & 1 deletion packages/frontend/src/lib/table/model/ModelColumnIcon.svelte
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
<script lang="ts">
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import ModelWhite from '../../icons/ModelWhite.svelte';
import { onMount } from 'svelte';
import { inferenceServers } from '/@/stores/inferenceServers';
import StatusIcon from '/@/lib/StatusIcon.svelte';
export let object: ModelInfo;
let status: string | undefined = undefined;
$: status;
onMount(() => {
return inferenceServers.subscribe(servers => {
if (servers.some(server => server.models.some(model => model.id === object.id))) {
status = 'USED';
} else {
status = object.file ? 'DOWNLOADED' : 'NONE';
}
});
});
</script>

<ModelWhite solid="{!!object.file}" />
<StatusIcon status="{status}" icon="{ModelWhite}" />
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { expect, test, vi, describe } from 'vitest';
import { render, screen, fireEvent, waitFor } from '@testing-library/svelte';
import ServiceStatus from './ServiceStatus.svelte';
import { studioClient } from '/@/utils/client';
import type { InferenceServerStatus } from '@shared/src/models/IInference';
import { type InferenceServerStatus, InferenceType } from '@shared/src/models/IInference';

vi.mock('../../../utils/client', async () => ({
studioClient: {
Expand All @@ -39,6 +39,7 @@ describe('transition statuses', () => {
connection: { port: 8888 },
status: status,
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});

Expand All @@ -62,6 +63,7 @@ describe('stable statuses', () => {
connection: { port: 8888 },
status: status,
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});

Expand All @@ -86,6 +88,7 @@ test('defined health should not display a spinner', async () => {
connection: { port: 8888 },
status: 'running',
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});

Expand All @@ -108,6 +111,7 @@ test('click on status icon should redirect to container', async () => {
connection: { port: 8888 },
status: 'running',
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});
// Get button and click on it
Expand All @@ -126,6 +130,7 @@ test('error status should show degraded', async () => {
connection: { port: 8888 },
status: 'error',
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});
// Get button and click on it
Expand All @@ -140,6 +145,7 @@ test('running status with no healthcheck should show starting', async () => {
connection: { port: 8888 },
status: 'running',
container: { containerId: 'dummyContainerId', engineId: 'dummyEngineId' },
type: InferenceType.LLAMA_CPP,
},
});
// Get button and click on it
Expand Down
10 changes: 10 additions & 0 deletions packages/frontend/src/pages/Models.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { screen, render, waitFor, within } from '@testing-library/svelte';
import Models from './Models.svelte';
import { router } from 'tinro';
import userEvent from '@testing-library/user-event';
import type { InferenceServer } from '@shared/src/models/IInference';

const mocks = vi.hoisted(() => {
return {
Expand All @@ -45,6 +46,15 @@ const mocks = vi.hoisted(() => {
};
});

vi.mock('../stores/inferenceServers', () => ({
inferenceServers: {
subscribe: (f: (msg: InferenceServer[]) => void) => {
f([]);
return () => {};
},
},
}));

vi.mock('/@/utils/client', async () => {
return {
studioClient: {
Expand Down
Loading

0 comments on commit 1024cee

Please sign in to comment.