From d4f1177c58becb24bd154e6716c5447089c1f3b7 Mon Sep 17 00:00:00 2001 From: xyinshen Date: Wed, 22 Nov 2023 21:28:23 +0800 Subject: [PATCH 1/9] feat: update register modal Signed-off-by: xyinshen --- .../common/forms/model_file_format.tsx | 1 - .../forms/model_tag_array_field/tag_field.tsx | 2 +- public/components/monitoring/index.tsx | 1 - .../monitoring/model_connector_filter.tsx | 1 - .../register_model/model_source.tsx | 77 +++++++++++ .../register_model/pretrainedmodel_select.tsx | 128 ++++++++++++++++++ .../register_model/register_model.tsx | 96 +++++++++---- .../register_model/register_model.types.ts | 2 +- public/components/register_model/utils.ts | 6 +- .../register_model_type_modal/index.tsx | 108 +++++---------- 10 files changed, 317 insertions(+), 105 deletions(-) create mode 100644 public/components/register_model/model_source.tsx create mode 100644 public/components/register_model/pretrainedmodel_select.tsx diff --git a/public/components/common/forms/model_file_format.tsx b/public/components/common/forms/model_file_format.tsx index 4da864da..4ef3a446 100644 --- a/public/components/common/forms/model_file_format.tsx +++ b/public/components/common/forms/model_file_format.tsx @@ -37,7 +37,6 @@ export const ModelFileFormatSelect = ({ readOnly = false }: Props) => { }); const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field; - const selectedFileFormatOption = useMemo(() => { if (fileFormatField.value) { return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value); diff --git a/public/components/common/forms/model_tag_array_field/tag_field.tsx b/public/components/common/forms/model_tag_array_field/tag_field.tsx index bdec5bda..a6bdc8ce 100644 --- a/public/components/common/forms/model_tag_array_field/tag_field.tsx +++ b/public/components/common/forms/model_tag_array_field/tag_field.tsx @@ -354,7 +354,7 @@ export const TagField = ({ onRemove(index)} /> diff --git a/public/components/monitoring/index.tsx b/public/components/monitoring/index.tsx index e8516a41..0ed035b8 100644 --- a/public/components/monitoring/index.tsx +++ b/public/components/monitoring/index.tsx @@ -43,7 +43,6 @@ export const Monitoring = () => { } = useMonitoring(); const [previewModel, setPreviewModel] = useState(null); const searchInputRef = useRef(); - const setInputRef = useCallback((node: HTMLInputElement | null) => { searchInputRef.current = node; }, []); diff --git a/public/components/monitoring/model_connector_filter.tsx b/public/components/monitoring/model_connector_filter.tsx index 42d5434e..6094494c 100644 --- a/public/components/monitoring/model_connector_filter.tsx +++ b/public/components/monitoring/model_connector_filter.tsx @@ -36,7 +36,6 @@ export const ModelConnectorFilter = ({ ), [internalConnectorsResult?.data, allExternalConnectors] ); - return ( { + const { allExternalConnectors } = useMonitoring(); + const CONNECTOR_OPTIONS = allExternalConnectors?.map((item) => { + return Object.assign({}, { label: item.name, value: item.description }); + }); + const { control } = useFormContext<{ modelConnector: string }>(); + + const modelConnectorController = useController({ + name: 'modelConnector', + control, + rules: { + required: { + value: true, + message: '', + }, + }, + }); + const { ref: fileFormatInputRef, ...fileFormatField } = modelConnectorController.field; + const selectedConnectorOption = useMemo(() => { + if (fileFormatField.value) { + return CONNECTOR_OPTIONS?.find((connector) => connector.value === fileFormatField.value); + } + }, [fileFormatField, CONNECTOR_OPTIONS]); + + const onConnectorChange = useCallback( + (options: Array>) => { + const value = options[0]?.value; + fileFormatField.onChange(value); + }, + [fileFormatField] + ); + return ( +
+ +

Model source

+
+ + + External model source explained, connector provisioning, etc.{' '} + + Learn more + + . + + + + + + +
+ ); +}; diff --git a/public/components/register_model/pretrainedmodel_select.tsx b/public/components/register_model/pretrainedmodel_select.tsx new file mode 100644 index 00000000..7ed6e95e --- /dev/null +++ b/public/components/register_model/pretrainedmodel_select.tsx @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useCallback, useState, Fragment, useEffect } from 'react'; +import { + EuiSpacer, + EuiTextColor, + EuiSelectable, + EuiLink, + EuiSelectableOption, + EuiHighlight, +} from '@elastic/eui'; +import { useHistory } from 'react-router-dom'; +import { generatePath } from 'react-router-dom'; +import { modelRepositoryManager } from '../../utils/model_repository_manager'; +import { routerPaths } from '../../../common/router_paths'; +interface IItem { + label: string; + checked?: 'on' | undefined; + description: string; +} +interface Props { + getPreSelected: (val: boolean) => void; +} +const renderModelOption = (option: IItem, searchValue: string) => { + return ( + <> + {option.label} +
+ + + {option.description} + + + + ); +}; +export const PreTrainedModelSelect = ({ getPreSelected }: Props) => { + useEffect(() => { + const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { + setModelRepoSelection( + Object.keys(models).map((name) => ({ + label: name, + description: models[name].description, + checked: undefined, + })) + ); + }); + return () => { + subscribe.unsubscribe(); + }; + }, []); + const ShowRest = useCallback( + (selected: boolean) => { + getPreSelected(selected); + }, + [getPreSelected] + ); + const [modelRepoSelection, setModelRepoSelection] = useState>>( + [] + ); + const history = useHistory(); + const onChange = useCallback( + (modelSelection: Array>) => { + setModelRepoSelection(modelSelection); + ShowRest(true); + }, + [ShowRest] + ); + useEffect(() => { + const selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); + if (selectedOption?.label) { + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ + selectedOption?.label + }&version=${selectedOption?.label}` + ); + } + }, [modelRepoSelection, history]); + return ( +
+ + Model + + +
+ + For more information on each model, see + + + + OpenSearch model repository documentation + + +
+ + + {(list, search) => ( + + {search} + {list} + + )} + +
+ ); +}; diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index daf09248..8525c395 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -18,12 +18,12 @@ import { EuiFlexItem, EuiTextColor, EuiLink, - EuiLoadingSpinner, + EuiFormRow, + EuiCheckbox, EuiPageContent, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; import { from } from 'rxjs'; - import { APIProvider } from '../../apis/api_provider'; import { useSearchParams } from '../../hooks/use_search_params'; import { isValidModelRegisterFormType } from './utils'; @@ -32,7 +32,7 @@ import { mountReactNode } from '../../../../../src/core/public/utils'; import { routerPaths } from '../../../common/router_paths'; import { ErrorCallOut } from '../../components/common'; import { modelRepositoryManager } from '../../utils/model_repository_manager'; - +import { PreTrainedModelSelect } from './pretrainedmodel_select'; import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; import { modelFileUploadManager } from './model_file_upload_manager'; @@ -43,6 +43,7 @@ import { ArtifactPanel } from './artifact'; import { ConfigurationPanel } from './model_configuration'; import { ModelTagsPanel } from './model_tags'; import { submitModelWithFile, submitModelWithURL } from './register_model_api'; +import { ModelSource } from './model_source'; const DEFAULT_VALUES = { name: '', @@ -88,13 +89,21 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo services: { chrome, notifications }, } = useOpenSearchDashboards(); const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); - + const [preSelected, setPreSelect] = useState(false); + const getPreSelected = (val: boolean) => { + setPreSelect(val); + }; const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; - const [preTrainedModelLoading, setPreTrainedModelLoading] = useState(formType === 'import'); const partials = formType === 'import' - ? [ModelDetailsPanel, ModelTagsPanel, ModelVersionNotesPanel] - : [ + ? [ + PreTrainedModelSelect, + ...(!preSelected ? [] : [ModelDetailsPanel]), + ...(!preSelected ? [] : [ModelTagsPanel]), + ...(!preSelected ? [] : [ModelVersionNotesPanel]), + ] + : formType === 'upload' + ? [ ...(registerToModelId ? [] : [ModelOverviewTitle]), ...(registerToModelId ? [] : [ModelDetailsPanel]), ...(registerToModelId ? [] : [ModelTagsPanel]), @@ -103,8 +112,13 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo ConfigurationPanel, ...(registerToModelId ? [ModelTagsPanel] : []), ModelVersionNotesPanel, + ] + : [ + ...(registerToModelId ? [] : [ModelDetailsPanel]), + ...(registerToModelId ? [] : [ModelTagsPanel]), + ModelSource, + ModelVersionNotesPanel, ]; - const form = useForm({ mode: 'onChange', defaultValues, @@ -232,7 +246,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const { config } = preTrainedModel; form.setValue('modelFileFormat', 'TORCH_SCRIPT'); if (config.name) { - form.setValue('name', `huggingface/${config.name}`); + form.setValue('name', config.name); } if (config.version) { form.setValue('version', config.version); @@ -240,7 +254,6 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo if (config.description) { form.setValue('description', config.description); } - setPreTrainedModelLoading(false); }, (error) => { // TODO: Should handle loading error here @@ -266,7 +279,17 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const errorCount = Object.keys(form.formState.errors).length; const formHeader = ( <> - + {registerToModelId && ( @@ -279,6 +302,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo . )} + {formType === 'external' && !registerToModelId && <>Description lorem.} {formType === 'import' && !registerToModelId && <>Register a pre-trained model.} {formType === 'upload' && !registerToModelId && ( <> @@ -290,17 +314,33 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo ); - - if (preTrainedModelLoading) { - return ( - - {formHeader} - - - - ); - } - + const [checked, setChecked] = useState(false); + const onChange = (e: any) => { + setChecked(e.target.checked); + }; + const formFooter = ( + +
+ {Needs a description} + {(formType === 'upload' || formType === 'import') && ( + onChange(e)} + /> + )} + {formType === 'external' && !registerToModelId && ( + onChange(e)} + /> + )} +
+
+ ); return ( ( - + {FormPartial === PreTrainedModelSelect ? ( + + ) : ( + + )} {FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? ( ) : ( @@ -334,6 +378,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} ))} + {formType === 'import' ? preSelected && formFooter : formFooter} @@ -352,6 +397,11 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} + + setIsSubmitted(true)} iconType="cross" color="ghost"> + Cancel + + void; @@ -39,19 +37,6 @@ interface IItem { checked?: 'on' | undefined; description: string; } -const renderModelOption = (option: IItem, searchValue: string) => { - return ( - <> - {option.label} -
- - - {option.description} - - - - ); -}; export function RegisterModelTypeModal({ onCloseModal }: Props) { const [modelRepoSelection, setModelRepoSelection] = useState>>( [] @@ -65,21 +50,21 @@ export function RegisterModelTypeModal({ onCloseModal }: Props) { (selectedOption) => { selectedOption = onChange(modelRepoSelection); switch (modelSource) { - case ModelSource.USER_MODEL: - selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); - if (selectedOption?.label) { - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ - selectedOption?.label - }&version=${selectedOption?.label}` - ); - } - break; case ModelSource.PRE_TRAINED_MODEL: + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import` + ); + break; + case ModelSource.USER_MODEL: history.push( `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=upload` ); break; + case ModelSource.EXTERNAL_MODEL: + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=external` + ); + break; } }, [history, modelSource, modelRepoSelection, onChange] @@ -127,8 +112,8 @@ export function RegisterModelTypeModal({ onCloseModal }: Props) { } aria-label="Opensearch model repository" - checked={modelSource === ModelSource.USER_MODEL} - onChange={() => setModelSource(ModelSource.USER_MODEL)} + checked={modelSource === ModelSource.PRE_TRAINED_MODEL} + onChange={() => setModelSource(ModelSource.PRE_TRAINED_MODEL)} />
@@ -146,58 +131,31 @@ export function RegisterModelTypeModal({ onCloseModal }: Props) { } aria-label="Add your own model" - checked={modelSource === ModelSource.PRE_TRAINED_MODEL} - onChange={() => setModelSource(ModelSource.PRE_TRAINED_MODEL)} + checked={modelSource === ModelSource.USER_MODEL} + onChange={() => setModelSource(ModelSource.USER_MODEL)} + /> + + + + External source + + + Connect to an external source with a connector. + + + } + aria-label="External source" + checked={modelSource === ModelSource.EXTERNAL_MODEL} + onChange={() => setModelSource(ModelSource.EXTERNAL_MODEL)} /> -
- - Model - - -
- - For more information on each model, see - - - - OpenSearch model repository documentation - - -
- - - {(list, search) => ( - - {search} - {list} - - )} - -
Date: Wed, 6 Dec 2023 12:51:14 +0800 Subject: [PATCH 2/9] feat: update base review Signed-off-by: xyinshen --- .../common/forms/model_file_format.tsx | 1 + public/components/monitoring/index.tsx | 1 + .../monitoring/model_connector_filter.tsx | 1 + .../register_model/model_deployment.tsx | 51 +++++++ .../register_model/model_source.tsx | 10 +- ...select.tsx => pretrained_model_select.tsx} | 18 +-- .../register_model/register_model.tsx | 133 +++++++----------- .../register_model/register_model.types.ts | 1 + 8 files changed, 119 insertions(+), 97 deletions(-) create mode 100644 public/components/register_model/model_deployment.tsx rename public/components/register_model/{pretrainedmodel_select.tsx => pretrained_model_select.tsx} (90%) diff --git a/public/components/common/forms/model_file_format.tsx b/public/components/common/forms/model_file_format.tsx index 4ef3a446..4da864da 100644 --- a/public/components/common/forms/model_file_format.tsx +++ b/public/components/common/forms/model_file_format.tsx @@ -37,6 +37,7 @@ export const ModelFileFormatSelect = ({ readOnly = false }: Props) => { }); const { ref: fileFormatInputRef, ...fileFormatField } = modelFileFormatController.field; + const selectedFileFormatOption = useMemo(() => { if (fileFormatField.value) { return FILE_FORMAT_OPTIONS.find((fmt) => fmt.value === fileFormatField.value); diff --git a/public/components/monitoring/index.tsx b/public/components/monitoring/index.tsx index 0ed035b8..e8516a41 100644 --- a/public/components/monitoring/index.tsx +++ b/public/components/monitoring/index.tsx @@ -43,6 +43,7 @@ export const Monitoring = () => { } = useMonitoring(); const [previewModel, setPreviewModel] = useState(null); const searchInputRef = useRef(); + const setInputRef = useCallback((node: HTMLInputElement | null) => { searchInputRef.current = node; }, []); diff --git a/public/components/monitoring/model_connector_filter.tsx b/public/components/monitoring/model_connector_filter.tsx index 6094494c..42d5434e 100644 --- a/public/components/monitoring/model_connector_filter.tsx +++ b/public/components/monitoring/model_connector_filter.tsx @@ -36,6 +36,7 @@ export const ModelConnectorFilter = ({ ), [internalConnectorsResult?.data, allExternalConnectors] ); + return ( { + const searchParams = useSearchParams(); + const typeParams = searchParams.get('type'); + const [checked, setChecked] = useState(false); + const { control } = useFormContext<{ deployment: boolean }>(); + const modelDeploymentController = useController({ + name: 'deployment', + control, + }); + + const { ref: deploymentInputRef, ...deploymentField } = modelDeploymentController.field; + const onDeploymentChange = useCallback( + (e) => { + setChecked(e.target.checked); + deploymentField.onChange(checked); + }, + [deploymentField, checked] + ); + return ( + +
+ {Needs a description} + {(typeParams === 'upload' || typeParams === 'import') && ( + + )} + {typeParams === 'external' && ( + + )} +
+
+ ); +}; diff --git a/public/components/register_model/model_source.tsx b/public/components/register_model/model_source.tsx index 288e7cc5..a9bdc4f5 100644 --- a/public/components/register_model/model_source.tsx +++ b/public/components/register_model/model_source.tsx @@ -18,8 +18,8 @@ import { useMonitoring } from '../monitoring/use_monitoring'; export const ModelSource = () => { const { allExternalConnectors } = useMonitoring(); - const CONNECTOR_OPTIONS = allExternalConnectors?.map((item) => { - return Object.assign({}, { label: item.name, value: item.description }); + const connectorOptions = allExternalConnectors?.map((item) => { + return Object.assign({}, { label: item.name, value: item.id }); }); const { control } = useFormContext<{ modelConnector: string }>(); @@ -36,9 +36,9 @@ export const ModelSource = () => { const { ref: fileFormatInputRef, ...fileFormatField } = modelConnectorController.field; const selectedConnectorOption = useMemo(() => { if (fileFormatField.value) { - return CONNECTOR_OPTIONS?.find((connector) => connector.value === fileFormatField.value); + return connectorOptions?.find((connector) => connector.value === fileFormatField.value); } - }, [fileFormatField, CONNECTOR_OPTIONS]); + }, [fileFormatField, connectorOptions]); const onConnectorChange = useCallback( (options: Array>) => { @@ -65,7 +65,7 @@ export const ModelSource = () => { void; -} const renderModelOption = (option: IItem, searchValue: string) => { return ( <> @@ -37,7 +34,7 @@ const renderModelOption = (option: IItem, searchValue: string) => { ); }; -export const PreTrainedModelSelect = ({ getPreSelected }: Props) => { +export const PreTrainedModelSelect = () => { useEffect(() => { const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { setModelRepoSelection( @@ -52,12 +49,6 @@ export const PreTrainedModelSelect = ({ getPreSelected }: Props) => { subscribe.unsubscribe(); }; }, []); - const ShowRest = useCallback( - (selected: boolean) => { - getPreSelected(selected); - }, - [getPreSelected] - ); const [modelRepoSelection, setModelRepoSelection] = useState>>( [] ); @@ -65,9 +56,10 @@ export const PreTrainedModelSelect = ({ getPreSelected }: Props) => { const onChange = useCallback( (modelSelection: Array>) => { setModelRepoSelection(modelSelection); - ShowRest(true); + // ShowRest(true); }, - [ShowRest] + // [ShowRest] + [] ); useEffect(() => { const selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); @@ -75,7 +67,7 @@ export const PreTrainedModelSelect = ({ getPreSelected }: Props) => { history.push( `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ selectedOption?.label - }&version=${selectedOption?.label}` + }` ); } }, [modelRepoSelection, history]); diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 8525c395..7c47ea2a 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -18,8 +18,6 @@ import { EuiFlexItem, EuiTextColor, EuiLink, - EuiFormRow, - EuiCheckbox, EuiPageContent, } from '@elastic/eui'; import useObservable from 'react-use/lib/useObservable'; @@ -32,7 +30,7 @@ import { mountReactNode } from '../../../../../src/core/public/utils'; import { routerPaths } from '../../../common/router_paths'; import { ErrorCallOut } from '../../components/common'; import { modelRepositoryManager } from '../../utils/model_repository_manager'; -import { PreTrainedModelSelect } from './pretrainedmodel_select'; +import { PreTrainedModelSelect } from './pretrained_model_select'; import { modelTaskManager } from './model_task_manager'; import { ModelVersionNotesPanel } from './model_version_notes'; import { modelFileUploadManager } from './model_file_upload_manager'; @@ -44,7 +42,7 @@ import { ConfigurationPanel } from './model_configuration'; import { ModelTagsPanel } from './model_tags'; import { submitModelWithFile, submitModelWithURL } from './register_model_api'; import { ModelSource } from './model_source'; - +import { ModelDeployment } from './model_deployment'; const DEFAULT_VALUES = { name: '', description: '', @@ -89,36 +87,42 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo services: { chrome, notifications }, } = useOpenSearchDashboards(); const isLocked = useObservable(chrome?.getIsNavDrawerLocked$() ?? from([false])); - const [preSelected, setPreSelect] = useState(false); - const getPreSelected = (val: boolean) => { - setPreSelect(val); - }; const formType = isValidModelRegisterFormType(typeParams) ? typeParams : 'upload'; - const partials = - formType === 'import' - ? [ - PreTrainedModelSelect, - ...(!preSelected ? [] : [ModelDetailsPanel]), - ...(!preSelected ? [] : [ModelTagsPanel]), - ...(!preSelected ? [] : [ModelVersionNotesPanel]), - ] - : formType === 'upload' - ? [ - ...(registerToModelId ? [] : [ModelOverviewTitle]), - ...(registerToModelId ? [] : [ModelDetailsPanel]), - ...(registerToModelId ? [] : [ModelTagsPanel]), - ...(registerToModelId ? [] : [FileAndVersionTitle]), - ArtifactPanel, - ConfigurationPanel, - ...(registerToModelId ? [ModelTagsPanel] : []), - ModelVersionNotesPanel, - ] - : [ - ...(registerToModelId ? [] : [ModelDetailsPanel]), - ...(registerToModelId ? [] : [ModelTagsPanel]), - ModelSource, - ModelVersionNotesPanel, - ]; + const partials = (() => { + if (formType === 'import') { + if (!nameParams) { + return [PreTrainedModelSelect]; + } + return [ + PreTrainedModelSelect, + ModelDetailsPanel, + ModelTagsPanel, + ModelVersionNotesPanel, + ModelDeployment, + ]; + } + if (formType === 'external') { + return [ + ...(registerToModelId ? [] : [ModelDetailsPanel]), + ...(registerToModelId ? [] : [ModelTagsPanel]), + ModelSource, + ModelVersionNotesPanel, + ModelDeployment, + ]; + } + return [ + ...(registerToModelId ? [] : [ModelOverviewTitle]), + ...(registerToModelId ? [] : [ModelDetailsPanel]), + ...(registerToModelId ? [] : [ModelTagsPanel]), + ...(registerToModelId ? [] : [FileAndVersionTitle]), + ArtifactPanel, + ConfigurationPanel, + ...(registerToModelId ? [ModelTagsPanel] : []), + ModelVersionNotesPanel, + ModelDeployment, + ]; + })(); + const form = useForm({ mode: 'onChange', defaultValues, @@ -275,21 +279,23 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo // eslint-disable-next-line no-console console.log(errors); }, []); - + const getPageTitle = () => { + if (registerToModelId) { + return 'Register version'; + } + switch (formType) { + case 'external': + return 'Register external model'; + case 'external': + return 'Register pre-trained model'; + default: + return 'Register your own model'; + } + }; const errorCount = Object.keys(form.formState.errors).length; const formHeader = ( <> - + {registerToModelId && ( @@ -314,33 +320,6 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo ); - const [checked, setChecked] = useState(false); - const onChange = (e: any) => { - setChecked(e.target.checked); - }; - const formFooter = ( - -
- {Needs a description} - {(formType === 'upload' || formType === 'import') && ( - onChange(e)} - /> - )} - {formType === 'external' && !registerToModelId && ( - onChange(e)} - /> - )} -
-
- ); return ( ( - {FormPartial === PreTrainedModelSelect ? ( - - ) : ( - - )} + {FormPartial === PreTrainedModelSelect ? : } {FormPartial === ModelOverviewTitle || FormPartial === FileAndVersionTitle ? ( ) : ( @@ -378,7 +353,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} ))} - {formType === 'import' ? preSelected && formFooter : formFooter} + {/* {formType === 'import' ? nameParams && formFooter : formFooter} */} @@ -398,7 +373,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} - setIsSubmitted(true)} iconType="cross" color="ghost"> + setIsSubmitted(false)} iconType="cross" color="ghost"> Cancel diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 85260fdf..33aa2635 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -14,6 +14,7 @@ export interface ModelFormBase { tags?: Tag[]; versionNotes?: string; type?: 'import' | 'upload' | 'external'; + deployment: boolean; } /** From 62a70b30f33c2b83269ed2f1b655c76dc09d3426 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 22 Dec 2023 16:29:50 +0800 Subject: [PATCH 3/9] Address PR comments Signed-off-by: Lin Wang --- .../register_model/model_deployment.tsx | 37 +++------- .../register_model/model_source.tsx | 17 +++-- .../pretrained_model_select.tsx | 71 +++++++++---------- .../register_model/register_model.tsx | 5 +- 4 files changed, 57 insertions(+), 73 deletions(-) diff --git a/public/components/register_model/model_deployment.tsx b/public/components/register_model/model_deployment.tsx index 27855849..97eceefd 100644 --- a/public/components/register_model/model_deployment.tsx +++ b/public/components/register_model/model_deployment.tsx @@ -3,48 +3,33 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useState } from 'react'; +import React from 'react'; import { EuiCheckbox, EuiText, EuiFormRow } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { useSearchParams } from '../../hooks/use_search_params'; export const ModelDeployment = () => { const searchParams = useSearchParams(); const typeParams = searchParams.get('type'); - const [checked, setChecked] = useState(false); const { control } = useFormContext<{ deployment: boolean }>(); const modelDeploymentController = useController({ name: 'deployment', control, }); + const isRegisterExternal = typeParams === 'external'; const { ref: deploymentInputRef, ...deploymentField } = modelDeploymentController.field; - const onDeploymentChange = useCallback( - (e) => { - setChecked(e.target.checked); - deploymentField.onChange(checked); - }, - [deploymentField, checked] - ); return ( - +
{Needs a description} - {(typeParams === 'upload' || typeParams === 'import') && ( - - )} - {typeParams === 'external' && ( - - )} +
); diff --git a/public/components/register_model/model_source.tsx b/public/components/register_model/model_source.tsx index a9bdc4f5..7c7d6407 100644 --- a/public/components/register_model/model_source.tsx +++ b/public/components/register_model/model_source.tsx @@ -12,15 +12,20 @@ import { EuiFormRow, EuiComboBox, } from '@elastic/eui'; - import { useController, useFormContext } from 'react-hook-form'; -import { useMonitoring } from '../monitoring/use_monitoring'; + +import { useFetcher } from '../../hooks'; +import { APIProvider } from '../../apis/api_provider'; export const ModelSource = () => { - const { allExternalConnectors } = useMonitoring(); - const connectorOptions = allExternalConnectors?.map((item) => { - return Object.assign({}, { label: item.name, value: item.id }); - }); + const { data: allConnectorsData } = useFetcher(APIProvider.getAPI('connector').getAll); + const connectorOptions = useMemo( + () => + allConnectorsData?.data?.map((item) => { + return Object.assign({}, { label: item.name, value: item.id }); + }), + [allConnectorsData] + ); const { control } = useFormContext<{ modelConnector: string }>(); const modelConnectorController = useController({ diff --git a/public/components/register_model/pretrained_model_select.tsx b/public/components/register_model/pretrained_model_select.tsx index f4d75485..763c03e7 100644 --- a/public/components/register_model/pretrained_model_select.tsx +++ b/public/components/register_model/pretrained_model_select.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, useState, Fragment, useEffect } from 'react'; +import React, { useCallback, Fragment } from 'react'; import { EuiSpacer, EuiTextColor, @@ -12,10 +12,12 @@ import { EuiSelectableOption, EuiHighlight, } from '@elastic/eui'; -import { useHistory } from 'react-router-dom'; -import { generatePath } from 'react-router-dom'; +import { useHistory, generatePath } from 'react-router-dom'; +import { useObservable } from 'react-use'; + import { modelRepositoryManager } from '../../utils/model_repository_manager'; import { routerPaths } from '../../../common/router_paths'; + interface IItem { label: string; checked?: 'on' | undefined; @@ -34,43 +36,36 @@ const renderModelOption = (option: IItem, searchValue: string) => { ); }; -export const PreTrainedModelSelect = () => { - useEffect(() => { - const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { - setModelRepoSelection( - Object.keys(models).map((name) => ({ - label: name, - description: models[name].description, - checked: undefined, - })) - ); - }); - return () => { - subscribe.unsubscribe(); - }; - }, []); - const [modelRepoSelection, setModelRepoSelection] = useState>>( - [] - ); +export const PreTrainedModelSelect = ({ + checkedPreTrainedModel, +}: { + checkedPreTrainedModel?: string; +}) => { + const preTrainedModels = useObservable(modelRepositoryManager.getPreTrainedModels$()); + const preTrainedModelOptions = preTrainedModels + ? Object.keys(preTrainedModels).map((name) => ({ + label: name, + description: preTrainedModels[name].description, + checked: checkedPreTrainedModel === name ? ('on' as const) : undefined, + })) + : []; + const history = useHistory(); const onChange = useCallback( - (modelSelection: Array>) => { - setModelRepoSelection(modelSelection); - // ShowRest(true); + (options: Array>) => { + const selectedOption = options.find((option) => option.checked === 'on'); + + if (selectedOption?.label) { + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ + selectedOption.label + }` + ); + } }, - // [ShowRest] - [] + [history] ); - useEffect(() => { - const selectedOption = modelRepoSelection.find((option) => option.checked === 'on'); - if (selectedOption?.label) { - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ - selectedOption?.label - }` - ); - } - }, [modelRepoSelection, history]); + return (
@@ -95,7 +90,7 @@ export const PreTrainedModelSelect = () => { 'data-test-subj': 'findModel', placeholder: 'Find model', }} - options={modelRepoSelection} + options={preTrainedModelOptions} onChange={onChange} singleSelection={true} noMatchesMessage="No model found" @@ -106,7 +101,7 @@ export const PreTrainedModelSelect = () => { 'data-test-subj': 'opensearchModelList', showIcons: true, }} - isLoading={modelRepoSelection.length === 0} + isLoading={!preTrainedModels} > {(list, search) => ( diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 7c47ea2a..49880983 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -295,7 +295,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const errorCount = Object.keys(form.formState.errors).length; const formHeader = ( <> - + {registerToModelId && ( @@ -353,7 +353,6 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo )} ))} - {/* {formType === 'import' ? nameParams && formFooter : formFooter} */} @@ -380,7 +379,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo setIsSubmitted(true)} From 11528b1cd72e96d54c04b2d14f5e30c7cf0fc48f Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Fri, 22 Dec 2023 17:07:06 +0800 Subject: [PATCH 4/9] Fix failed tests Signed-off-by: Lin Wang --- .../register_model/register_model.tsx | 2 +- .../__tests__/index.test.tsx | 44 +------------------ 2 files changed, 3 insertions(+), 43 deletions(-) diff --git a/public/components/register_model/register_model.tsx b/public/components/register_model/register_model.tsx index 49880983..24b4d78c 100644 --- a/public/components/register_model/register_model.tsx +++ b/public/components/register_model/register_model.tsx @@ -250,7 +250,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const { config } = preTrainedModel; form.setValue('modelFileFormat', 'TORCH_SCRIPT'); if (config.name) { - form.setValue('name', config.name); + form.setValue('name', `huggingface/${config.name}`); } if (config.version) { form.setValue('version', config.version); diff --git a/public/components/register_model_type_modal/__tests__/index.test.tsx b/public/components/register_model_type_modal/__tests__/index.test.tsx index 4252cb86..a2935b74 100644 --- a/public/components/register_model_type_modal/__tests__/index.test.tsx +++ b/public/components/register_model_type_modal/__tests__/index.test.tsx @@ -37,16 +37,11 @@ const mockOffsetMethods = () => { }; describe('', () => { - it('should render two checkablecard', () => { + it('should render three checkablecard', () => { render( {}} />); expect(screen.getByLabelText('Opensearch model repository')).toBeInTheDocument(); expect(screen.getByLabelText('Add your own model')).toBeInTheDocument(); - }); - - it('should render select with Opensearch model repository', () => { - render( {}} />); - expect(screen.getByLabelText('Opensearch model repository')).toBeInTheDocument(); - expect(screen.getByLabelText('OpenSearch model repository models')).toBeInTheDocument(); + expect(screen.getByLabelText('External source')).toBeInTheDocument(); }); it('should call onCloseModal after click "cancel"', async () => { @@ -55,39 +50,4 @@ describe('', () => { await userEvent.click(screen.getByTestId('cancelRegister')); expect(onClickMock).toHaveBeenCalled(); }); - - it('should call opensearch model repository model list and link to url with selected option after click "Find model" and continue', async () => { - const mockReset = mockOffsetMethods(); - render( {}} />); - await userEvent.click(screen.getByLabelText('Opensearch model repository')); - expect(screen.getByTestId('findModel')).toBeInTheDocument(); - expect(screen.getByTestId('opensearchModelList')).toBeInTheDocument(); - await waitFor(() => - expect(screen.getByText('sentence-transformers/all-distilroberta-v1')).toBeInTheDocument() - ); - await userEvent.click(screen.getByText('sentence-transformers/all-distilroberta-v1')); - await userEvent.click(screen.getByTestId('continueRegister')); - expect(document.URL).toContain( - 'model-registry/register-model/?type=import&name=sentence-transformers/all-distilroberta-v1&version=sentence-transformers/all-distilroberta-v1' - ); - mockReset(); - }); - - it('should render no model found when input a invalid text to search model', async () => { - const mockReset = mockOffsetMethods(); - render( {}} />); - await userEvent.click(screen.getByLabelText('Opensearch model repository')); - await waitFor(() => - expect(screen.getByText('sentence-transformers/all-distilroberta-v1')).toBeInTheDocument() - ); - await userEvent.type(screen.getByTestId('findModel'), 'foo'); - expect(screen.getByText('No model found')).toBeInTheDocument(); - mockReset(); - }); - - it('should link href after selecting "add your own model" and continue ', async () => { - render( {}} />); - await userEvent.click(screen.getByTestId('continueRegister')); - expect(document.URL).toEqual('http://localhost/model-registry/register-model/?type=upload'); - }); }); From 1b3a473b95e61870581b99117f64d50e3a010437 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 8 Jan 2024 16:40:06 +0800 Subject: [PATCH 5/9] feat: support register external model Signed-off-by: Lin Wang --- public/apis/model_version.ts | 3 +- .../register_model/__tests__/setup.tsx | 8 ++--- .../register_model/model_deployment.tsx | 2 ++ .../register_model/model_details.tsx | 4 +-- .../register_model/model_source.tsx | 11 +++---- .../register_model/register_model.tsx | 15 ++++++---- .../register_model/register_model.types.ts | 12 +++++++- .../register_model/register_model_api.ts | 29 +++++++++++++++++-- server/routes/model_version_router.ts | 11 ++++++- server/services/model_version_service.ts | 24 +++++++++------ 10 files changed, 88 insertions(+), 31 deletions(-) diff --git a/public/apis/model_version.ts b/public/apis/model_version.ts index a9c78c00..3bb705f3 100644 --- a/public/apis/model_version.ts +++ b/public/apis/model_version.ts @@ -84,8 +84,9 @@ interface UploadModelBase { name: string; version?: string; description?: string; - modelFormat: string; + modelFormat?: string; modelId: string; + deployment: boolean; } export interface UploadModelByURL extends UploadModelBase { diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index 5095b1bc..e6102926 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -10,13 +10,13 @@ import { UserEvent } from '@testing-library/user-event/dist/types/setup/setup'; import { RegisterModelForm } from '../register_model'; import { render, RenderWithRouteProps, screen, waitFor } from '../../../../test/test_utils'; -import { ModelFileFormData, ModelUrlFormData } from '../register_model.types'; +import { ModelFormData } from '../register_model.types'; jest.mock('../../../apis/task'); interface SetupOptions extends Partial { mode?: 'model' | 'version' | 'import'; - defaultValues?: Partial | Partial; + defaultValues?: Partial; } interface SetupReturn { @@ -45,12 +45,12 @@ const DEFAULT_VALUES = { export async function setup(options: { route?: string; mode: 'version'; - defaultValues?: Partial | Partial; + defaultValues?: Partial; }): Promise>; export async function setup(options?: { route?: string; mode?: 'model' | 'import'; - defaultValues?: Partial | Partial; + defaultValues?: Partial; }): Promise; export async function setup( { route, mode, defaultValues }: SetupOptions = { diff --git a/public/components/register_model/model_deployment.tsx b/public/components/register_model/model_deployment.tsx index 97eceefd..718fa758 100644 --- a/public/components/register_model/model_deployment.tsx +++ b/public/components/register_model/model_deployment.tsx @@ -14,6 +14,7 @@ export const ModelDeployment = () => { const modelDeploymentController = useController({ name: 'deployment', control, + defaultValue: false, }); const isRegisterExternal = typeParams === 'external'; @@ -24,6 +25,7 @@ export const ModelDeployment = () => { {Needs a description} { - const { control, trigger, watch } = useFormContext(); + const { control, trigger, watch } = useFormContext(); const type = watch('type'); return ( diff --git a/public/components/register_model/model_source.tsx b/public/components/register_model/model_source.tsx index 7c7d6407..4a342497 100644 --- a/public/components/register_model/model_source.tsx +++ b/public/components/register_model/model_source.tsx @@ -16,6 +16,7 @@ import { useController, useFormContext } from 'react-hook-form'; import { useFetcher } from '../../hooks'; import { APIProvider } from '../../apis/api_provider'; +import { ExternalModelFormData } from './register_model.types'; export const ModelSource = () => { const { data: allConnectorsData } = useFetcher(APIProvider.getAPI('connector').getAll); @@ -26,19 +27,19 @@ export const ModelSource = () => { }), [allConnectorsData] ); - const { control } = useFormContext<{ modelConnector: string }>(); + const { control } = useFormContext(); const modelConnectorController = useController({ - name: 'modelConnector', + name: 'connectorId', control, rules: { required: { value: true, - message: '', + message: 'Model connector is required', }, }, }); - const { ref: fileFormatInputRef, ...fileFormatField } = modelConnectorController.field; + const { ref: connectorInputRef, ...fileFormatField } = modelConnectorController.field; const selectedConnectorOption = useMemo(() => { if (fileFormatField.value) { return connectorOptions?.find((connector) => connector.value === fileFormatField.value); @@ -69,7 +70,7 @@ export const ModelSource = () => { | Partial; + defaultValues?: Partial; } const ModelOverviewTitle = () => { @@ -123,7 +123,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo ]; })(); - const form = useForm({ + const form = useForm({ mode: 'onChange', defaultValues, criteriaMode: 'all', @@ -133,7 +133,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo const formErrors = useMemo(() => ({ ...form.formState.errors }), [form.formState]); const onSubmit = useCallback( - async (data: ModelFileFormData | ModelUrlFormData | ModelFormBase) => { + async (data: ModelFormData) => { try { const onComplete = () => { notifications?.toasts.addSuccess({ @@ -169,6 +169,9 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo onError, }); modelId = result.modelId; + } else if ('connectorId' in data) { + const result = await submitExternalModel(data); + modelId = result.modelId; } else { const result = await submitModelWithURL(data); modelTaskManager.query({ @@ -274,7 +277,7 @@ export const RegisterModelForm = ({ defaultValues = DEFAULT_VALUES }: RegisterMo form.setValue('type', formType); }, [formType, form]); - const onError = useCallback((errors: FieldErrors) => { + const onError = useCallback((errors: FieldErrors) => { // TODO // eslint-disable-next-line no-console console.log(errors); diff --git a/public/components/register_model/register_model.types.ts b/public/components/register_model/register_model.types.ts index 33aa2635..c955932c 100644 --- a/public/components/register_model/register_model.types.ts +++ b/public/components/register_model/register_model.types.ts @@ -10,7 +10,6 @@ export interface ModelFormBase { version?: string; modelId?: string; description?: string; - modelFileFormat: string; tags?: Tag[]; versionNotes?: string; type?: 'import' | 'upload' | 'external'; @@ -23,6 +22,7 @@ export interface ModelFormBase { export interface ModelFileFormData extends ModelFormBase { modelFile: File; configuration: string; + modelFileFormat: string; } /** @@ -31,4 +31,14 @@ export interface ModelFileFormData extends ModelFormBase { export interface ModelUrlFormData extends ModelFormBase { modelURL: string; configuration: string; + modelFileFormat: string; +} + +/** + * The type of the external model form data via connector + */ +export interface ExternalModelFormData extends ModelFormBase { + connectorId: string; } + +export type ModelFormData = ModelFileFormData | ModelUrlFormData | ExternalModelFormData; diff --git a/public/components/register_model/register_model_api.ts b/public/components/register_model/register_model_api.ts index e090e3cb..3ca6b657 100644 --- a/public/components/register_model/register_model_api.ts +++ b/public/components/register_model/register_model_api.ts @@ -6,7 +6,12 @@ import { APIProvider } from '../../apis/api_provider'; import { MAX_CHUNK_SIZE } from '../common/forms/form_constants'; import { getModelContentHashValue } from './get_model_content_hash_value'; -import { ModelFileFormData, ModelUrlFormData, ModelFormBase } from './register_model.types'; +import { + ModelFileFormData, + ModelUrlFormData, + ModelFormBase, + ExternalModelFormData, +} from './register_model.types'; const getModelUploadBase = ({ name, @@ -14,9 +19,11 @@ const getModelUploadBase = ({ versionNotes, modelFileFormat, configuration, -}: ModelFormBase & { configuration?: string }) => ({ + deployment, +}: ModelFormBase & { configuration?: string; modelFileFormat?: string }) => ({ name, version, + deployment, description: versionNotes, modelFormat: modelFileFormat, modelConfig: configuration ? JSON.parse(configuration) : undefined, @@ -96,3 +103,21 @@ export async function submitModelWithURL(model: ModelUrlFormData | ModelFormBase taskId: result.uploadResult.task_id, }; } + +export async function submitExternalModel(model: ExternalModelFormData) { + const result = await createModelIfNeedAndUploadVersion({ + ...model, + uploader: (modelId: string) => + APIProvider.getAPI('modelVersion').upload({ + ...getModelUploadBase(model), + modelFormat: undefined, + modelId, + connectorId: model.connectorId, + }), + }); + + return { + modelId: result.modelId, + taskId: result.uploadResult.task_id, + }; +} diff --git a/server/routes/model_version_router.ts b/server/routes/model_version_router.ts index f06023c1..b54fa087 100644 --- a/server/routes/model_version_router.ts +++ b/server/routes/model_version_router.ts @@ -55,21 +55,29 @@ const modelUploadBaseSchema = schema.object({ name: schema.string(), version: schema.maybe(schema.string()), description: schema.maybe(schema.string()), - modelFormat: schema.string(), modelId: schema.string(), + deployment: schema.boolean({ + defaultValue: false, + }), }); const modelUploadByURLSchema = modelUploadBaseSchema.extends({ + modelFormat: schema.string(), url: schema.string(), modelConfig: schema.object({}, { unknowns: 'allow' }), }); const modelUploadByChunkSchema = modelUploadBaseSchema.extends({ + modelFormat: schema.string(), modelContentHashValue: schema.string(), totalChunks: schema.number(), modelConfig: schema.object({}, { unknowns: 'allow' }), }); +const modelUploadByExternalSchema = modelUploadBaseSchema.extends({ + connectorId: schema.string(), +}); + export const modelVersionRouter = (router: IRouter) => { router.get( { @@ -256,6 +264,7 @@ export const modelVersionRouter = (router: IRouter) => { modelUploadByURLSchema, modelUploadByChunkSchema, modelUploadBaseSchema, + modelUploadByExternalSchema, ]), }, }, diff --git a/server/services/model_version_service.ts b/server/services/model_version_service.ts index 5f43e086..0e216c5a 100644 --- a/server/services/model_version_service.ts +++ b/server/services/model_version_service.ts @@ -40,21 +40,27 @@ interface UploadModelBase { name: string; version?: string; description?: string; - modelFormat: string; modelId: string; + deployment: boolean; } interface UploadModelByURL extends UploadModelBase { url: string; + modelFormat: string; modelConfig: Record; } interface UploadModelByChunk extends UploadModelBase { modelContentHashValue: string; totalChunks: number; + modelFormat: string; modelConfig: Record; } +interface UploadExternalModel extends UploadModelBase { + connectorId: string; +} + type UploadResultInner = T extends UploadModelByChunk ? { model_version_id: string; status: string } : { task_id: string; status: string }; @@ -170,14 +176,10 @@ export class ModelVersionService { ).body; } - public static async upload({ - client, - model, - }: { - client: IScopedClusterClient; - model: T; - }): UploadResult { - const { name, version, description, modelFormat, modelId } = model; + public static async upload< + T extends UploadModelByChunk | UploadModelByURL | UploadExternalModel | UploadModelBase + >({ client, model }: { client: IScopedClusterClient; model: T }): UploadResult { + const { name, version, description, modelFormat, modelId, deployment } = model; const uploadModelBase = { name, version, @@ -204,8 +206,12 @@ export class ModelVersionService { await client.asCurrentUser.transport.request({ method: 'POST', path: MODEL_UPLOAD_API, + querystring: deployment ? { deploy: deployment } : {}, body: { ...uploadModelBase, + ...('connectorId' in model + ? { connector_id: model.connectorId, function_name: 'remote' } + : {}), url: 'url' in model ? model.url : undefined, }, }) From 35f0e87d840b7b242a2efd89c9dece2b764ae764 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Mon, 8 Jan 2024 18:39:30 +0800 Subject: [PATCH 6/9] test: add unit tests for model connector Signed-off-by: Lin Wang --- .../__tests__/register_model_source.test.tsx | 53 +++++++++++++++++++ .../register_model/__tests__/setup.tsx | 13 +++-- .../register_model/model_source.tsx | 19 ++++--- 3 files changed, 73 insertions(+), 12 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_source.test.tsx diff --git a/public/components/register_model/__tests__/register_model_source.test.tsx b/public/components/register_model/__tests__/register_model_source.test.tsx new file mode 100644 index 00000000..893f7226 --- /dev/null +++ b/public/components/register_model/__tests__/register_model_source.test.tsx @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { setup } from './setup'; +import * as formAPI from '../register_model_api'; +import { Connector } from '../../../apis/connector'; +import { screen } from '../../../../test/test_utils'; + +describe(' Source', () => { + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); + + beforeEach(() => { + jest.spyOn(formAPI, 'submitExternalModel').mockImplementation(onSubmitMock); + jest.spyOn(Connector.prototype, 'getAll').mockResolvedValue({ + data: [ + { id: 'connector-1', name: 'foo' }, + { id: 'connector-2', name: 'bar' }, + ], + total_connectors: 2, + }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should render a model source panel', async () => { + const result = await setup({ mode: 'external', route: '/?type=external' }); + expect(result.connectorInput).toBeInTheDocument(); + }); + + it('should submit the register model form', async () => { + const result = await setup({ mode: 'external', route: '/?type=external' }); + expect(onSubmitMock).not.toHaveBeenCalled(); + + await result.user.click(result.connectorInput); + await result.user.click(screen.getByRole('option', { name: 'foo' })); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalled(); + }); + + it('should NOT submit the register model form if model source is empty', async () => { + const result = await setup({ mode: 'external', route: '/?type=external' }); + + await result.user.click(result.submitButton); + + expect(result.connectorInput.closest('[class*="isInvalid"]')).toBeInTheDocument(); + expect(onSubmitMock).not.toHaveBeenCalled(); + }); +}); diff --git a/public/components/register_model/__tests__/setup.tsx b/public/components/register_model/__tests__/setup.tsx index e6102926..80c00d3c 100644 --- a/public/components/register_model/__tests__/setup.tsx +++ b/public/components/register_model/__tests__/setup.tsx @@ -15,7 +15,7 @@ import { ModelFormData } from '../register_model.types'; jest.mock('../../../apis/task'); interface SetupOptions extends Partial { - mode?: 'model' | 'version' | 'import'; + mode?: 'model' | 'version' | 'import' | 'external'; defaultValues?: Partial; } @@ -26,6 +26,7 @@ interface SetupReturn { form: HTMLElement; user: UserEvent; versionNotesInput: HTMLTextAreaElement; + connectorInput: HTMLSelectElement; } const CONFIGURATION = `{ @@ -46,10 +47,10 @@ export async function setup(options: { route?: string; mode: 'version'; defaultValues?: Partial; -}): Promise>; +}): Promise>; export async function setup(options?: { route?: string; - mode?: 'model' | 'import'; + mode?: 'model' | 'import' | 'external'; defaultValues?: Partial; }): Promise; export async function setup( @@ -75,6 +76,7 @@ export async function setup( const form = screen.getByTestId('mlCommonsPlugin-registerModelForm'); const user = userEvent.setup(); const versionNotesInput = screen.getByLabelText(/notes/i); + const connectorInput = screen.queryByLabelText('Model connector'); // fill model file if (modelFileInput) { @@ -106,11 +108,11 @@ export async function setup( } // fill model name - if (mode === 'model') { + if (mode === 'model' || mode === 'external') { await user.type(nameInput, 'test model name'); } // fill model description - if (mode === 'model') { + if (mode === 'model' || mode === 'external') { await user.type(descriptionInput, 'test model description'); } @@ -121,5 +123,6 @@ export async function setup( form, user, versionNotesInput, + connectorInput, }; } diff --git a/public/components/register_model/model_source.tsx b/public/components/register_model/model_source.tsx index 4a342497..7b69fccd 100644 --- a/public/components/register_model/model_source.tsx +++ b/public/components/register_model/model_source.tsx @@ -39,19 +39,19 @@ export const ModelSource = () => { }, }, }); - const { ref: connectorInputRef, ...fileFormatField } = modelConnectorController.field; + const { ref: connectorInputRef, ...connectorField } = modelConnectorController.field; const selectedConnectorOption = useMemo(() => { - if (fileFormatField.value) { - return connectorOptions?.find((connector) => connector.value === fileFormatField.value); + if (connectorField.value) { + return connectorOptions?.find((connector) => connector.value === connectorField.value); } - }, [fileFormatField, connectorOptions]); + }, [connectorField, connectorOptions]); const onConnectorChange = useCallback( (options: Array>) => { const value = options[0]?.value; - fileFormatField.onChange(value); + connectorField.onChange(value); }, - [fileFormatField] + [connectorField] ); return (
@@ -68,7 +68,11 @@ export const ModelSource = () => { - + { selectedOptions={selectedConnectorOption ? [selectedConnectorOption] : []} placeholder="Select a connector" onChange={onConnectorChange} + isInvalid={Boolean(modelConnectorController.fieldState.error)} />
From 37f110851a3850e60acbe4d2278780ba3793f733 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Tue, 9 Jan 2024 11:14:11 +0800 Subject: [PATCH 7/9] test: add unit tests for model automatic deployment Signed-off-by: Lin Wang --- .../register_model_deployment.test.tsx | 57 +++++++++++++++++++ .../register_model/model_deployment.tsx | 31 +++++----- 2 files changed, 75 insertions(+), 13 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_deployment.test.tsx diff --git a/public/components/register_model/__tests__/register_model_deployment.test.tsx b/public/components/register_model/__tests__/register_model_deployment.test.tsx new file mode 100644 index 00000000..68c5978e --- /dev/null +++ b/public/components/register_model/__tests__/register_model_deployment.test.tsx @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { setup } from './setup'; +import * as formAPI from '../register_model_api'; +import { screen } from '../../../../test/test_utils'; + +describe(' Deployment', () => { + const onSubmitMock = jest.fn().mockResolvedValue('model_id'); + + beforeEach(() => { + jest.spyOn(formAPI, 'submitExternalModel').mockImplementation(onSubmitMock); + jest.spyOn(formAPI, 'submitModelWithFile').mockImplementation(onSubmitMock); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should render a model deployment panel', async () => { + await setup(); + expect(screen.getByLabelText('Deployment')).toBeInTheDocument(); + }); + + it('should render a model activation panel', async () => { + await setup({ mode: 'external', route: '/?type=external' }); + expect(screen.getByLabelText('Activation')).toBeInTheDocument(); + }); + + it('should submit the register model form without automatic deployment flag', async () => { + const result = await setup(); + expect(onSubmitMock).not.toHaveBeenCalled(); + + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + deployment: false, + }) + ); + }); + it('should submit the register model form with automatic deployment flag', async () => { + const result = await setup(); + expect(onSubmitMock).not.toHaveBeenCalled(); + + await result.user.click(screen.getByLabelText('Start deployment automatically')); + await result.user.click(result.submitButton); + + expect(onSubmitMock).toHaveBeenCalledWith( + expect.objectContaining({ + deployment: true, + }) + ); + }); +}); diff --git a/public/components/register_model/model_deployment.tsx b/public/components/register_model/model_deployment.tsx index 718fa758..636ac7bc 100644 --- a/public/components/register_model/model_deployment.tsx +++ b/public/components/register_model/model_deployment.tsx @@ -7,6 +7,7 @@ import React from 'react'; import { EuiCheckbox, EuiText, EuiFormRow } from '@elastic/eui'; import { useController, useFormContext } from 'react-hook-form'; import { useSearchParams } from '../../hooks/use_search_params'; + export const ModelDeployment = () => { const searchParams = useSearchParams(); const typeParams = searchParams.get('type'); @@ -20,19 +21,23 @@ export const ModelDeployment = () => { const { ref: deploymentInputRef, ...deploymentField } = modelDeploymentController.field; return ( - -
- {Needs a description} - -
+ + Need a description, mention of “in use” might make sense +
+ } + > + ); }; From 1d05eca7c10d42ecf7086407000edec27dba3752 Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 10 Jan 2024 13:05:31 +0800 Subject: [PATCH 8/9] refactor model selector with combobox Signed-off-by: Lin Wang --- .../register_model_repository_import.test.tsx | 57 ++++++++++++ .../pretrained_model_select.tsx | 87 +++++++------------ 2 files changed, 89 insertions(+), 55 deletions(-) create mode 100644 public/components/register_model/__tests__/register_model_repository_import.test.tsx diff --git a/public/components/register_model/__tests__/register_model_repository_import.test.tsx b/public/components/register_model/__tests__/register_model_repository_import.test.tsx new file mode 100644 index 00000000..c22bcfdd --- /dev/null +++ b/public/components/register_model/__tests__/register_model_repository_import.test.tsx @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +import React from 'react'; +import userEvent from '@testing-library/user-event'; +import { Route } from 'react-router-dom'; + +import { render, screen, history, waitFor } from '../../../../test/test_utils'; +import { RegisterModelForm } from '../register_model'; +import { ModelRepository } from '../../../apis/model_repository'; + +describe(' Repository Import', () => { + beforeEach(() => { + jest.spyOn(ModelRepository.prototype, 'getPreTrainedModels').mockResolvedValue({ + foo: { + description: 'foo', + version: '1', + torch_script: { model_url: '', config_url: '' }, + onnx: { model_url: '', config_url: '' }, + }, + }); + jest.spyOn(ModelRepository.prototype, 'getPreTrainedModelConfig').mockResolvedValue({}); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('should render find model selector and disable "Register model" button', async () => { + render( + + + , + { route: '/?type=import' } + ); + + await waitFor(() => { + expect(screen.getByText('Find model')).toBeInTheDocument(); + expect(screen.queryByLabelText(/^name$/i)).not.toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Register model' })).toBeDisabled(); + }); + }); + + it('should update path with name params after model selected', async () => { + render( + + + , + { route: '/?type=import' } + ); + + await userEvent.click(screen.getByText('Find model')); + await userEvent.click(screen.getByRole('option', { name: 'foo' })); + expect(history.current.location.search).toContain('name=foo'); + }); +}); diff --git a/public/components/register_model/pretrained_model_select.tsx b/public/components/register_model/pretrained_model_select.tsx index 763c03e7..a4410ace 100644 --- a/public/components/register_model/pretrained_model_select.tsx +++ b/public/components/register_model/pretrained_model_select.tsx @@ -3,62 +3,56 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React, { useCallback, Fragment } from 'react'; +import React, { useCallback } from 'react'; import { EuiSpacer, EuiTextColor, - EuiSelectable, EuiLink, - EuiSelectableOption, EuiHighlight, + EuiComboBox, + EuiComboBoxProps, } from '@elastic/eui'; import { useHistory, generatePath } from 'react-router-dom'; import { useObservable } from 'react-use'; import { modelRepositoryManager } from '../../utils/model_repository_manager'; import { routerPaths } from '../../../common/router_paths'; +import { useSearchParams } from '../../hooks/use_search_params'; -interface IItem { - label: string; - checked?: 'on' | undefined; - description: string; -} -const renderModelOption = (option: IItem, searchValue: string) => { - return ( - <> - {option.label} -
+type PreTrainedModelProps = Required>; + +const renderOption: PreTrainedModelProps['renderOption'] = (option, searchValue) => ( + <> + {option.value && {option.value?.name}} +
+ {option.value && ( - {option.description} + {option.value.description} - - ); -}; -export const PreTrainedModelSelect = ({ - checkedPreTrainedModel, -}: { - checkedPreTrainedModel?: string; -}) => { + )} + +); + +export const PreTrainedModelSelect = () => { + const searchParams = useSearchParams(); + const nameParams = searchParams.get('name'); const preTrainedModels = useObservable(modelRepositoryManager.getPreTrainedModels$()); const preTrainedModelOptions = preTrainedModels ? Object.keys(preTrainedModels).map((name) => ({ label: name, - description: preTrainedModels[name].description, - checked: checkedPreTrainedModel === name ? ('on' as const) : undefined, + value: { name, description: preTrainedModels[name].description }, })) : []; const history = useHistory(); - const onChange = useCallback( - (options: Array>) => { - const selectedOption = options.find((option) => option.checked === 'on'); - - if (selectedOption?.label) { + const onChange = useCallback( + (options) => { + if (options[0].value) { history.push( `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import&name=${ - selectedOption.label + options[0].value.name }` ); } @@ -83,33 +77,16 @@ export const PreTrainedModelSelect = ({
- - {(list, search) => ( - - {search} - {list} - - )} - + isClearable={false} + /> ); }; From 222ab9f59e2b2fcba436431347d6db09ac7a038a Mon Sep 17 00:00:00 2001 From: Lin Wang Date: Wed, 10 Jan 2024 15:07:58 +0800 Subject: [PATCH 9/9] test: add miss tests for register model type modal Signed-off-by: Lin Wang --- .../__tests__/index.test.tsx | 77 ++++++++++--------- .../register_model_type_modal/index.tsx | 68 ++++------------ 2 files changed, 59 insertions(+), 86 deletions(-) diff --git a/public/components/register_model_type_modal/__tests__/index.test.tsx b/public/components/register_model_type_modal/__tests__/index.test.tsx index a2935b74..79b3ac1d 100644 --- a/public/components/register_model_type_modal/__tests__/index.test.tsx +++ b/public/components/register_model_type_modal/__tests__/index.test.tsx @@ -5,49 +5,56 @@ import React from 'react'; import userEvent from '@testing-library/user-event'; + import { RegisterModelTypeModal } from '../index'; -import { render, screen, waitFor } from '../../../../test/test_utils'; - -const mockOffsetMethods = () => { - const originalOffsetHeight = Object.getOwnPropertyDescriptor( - HTMLElement.prototype, - 'offsetHeight' - ); - const originalOffsetWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'offsetWidth'); - Object.defineProperty(HTMLElement.prototype, 'offsetHeight', { - configurable: true, - value: 600, - }); - Object.defineProperty(HTMLElement.prototype, 'offsetWidth', { - configurable: true, - value: 600, - }); - return () => { - Object.defineProperty( - HTMLElement.prototype, - 'offsetHeight', - originalOffsetHeight as PropertyDescriptor - ); - Object.defineProperty( - HTMLElement.prototype, - 'offsetWidth', - originalOffsetWidth as PropertyDescriptor - ); - }; -}; +import { render, screen, history } from '../../../../test/test_utils'; describe('', () => { - it('should render three checkablecard', () => { - render( {}} />); + it('should render three checkable card', () => { + render(); expect(screen.getByLabelText('Opensearch model repository')).toBeInTheDocument(); expect(screen.getByLabelText('Add your own model')).toBeInTheDocument(); expect(screen.getByLabelText('External source')).toBeInTheDocument(); }); - it('should call onCloseModal after click "cancel"', async () => { - const onClickMock = jest.fn(); - render(); + it('should call onCloseModal after cancel button click', async () => { + const onCloseModalMock = jest.fn(); + render(); await userEvent.click(screen.getByTestId('cancelRegister')); - expect(onClickMock).toHaveBeenCalled(); + expect(onCloseModalMock).toHaveBeenCalled(); + }); + + it('should call onCloseModal after modal close icon click', async () => { + const onCloseModalMock = jest.fn(); + render(); + await userEvent.click(screen.getByLabelText('Closes this modal window')); + expect(onCloseModalMock).toHaveBeenCalled(); + }); + + it('should go to repository model import page', async () => { + render(); + + await userEvent.click(screen.getByLabelText('Opensearch model repository')); + await userEvent.click(screen.getByText('Continue')); + expect(history.current.location.pathname).toEqual('/model-registry/register-model/'); + expect(history.current.location.search).toContain('type=import'); + }); + + it('should go to model upload page', async () => { + render(); + + await userEvent.click(screen.getByLabelText('Add your own model')); + await userEvent.click(screen.getByText('Continue')); + expect(history.current.location.pathname).toEqual('/model-registry/register-model/'); + expect(history.current.location.search).toContain('type=upload'); + }); + + it('should go to external model register page', async () => { + render(); + + await userEvent.click(screen.getByLabelText('External source')); + await userEvent.click(screen.getByText('Continue')); + expect(history.current.location.pathname).toEqual('/model-registry/register-model/'); + expect(history.current.location.search).toContain('type=external'); }); }); diff --git a/public/components/register_model_type_modal/index.tsx b/public/components/register_model_type_modal/index.tsx index 523868b0..95a9a4d3 100644 --- a/public/components/register_model_type_modal/index.tsx +++ b/public/components/register_model_type_modal/index.tsx @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ import { EuiSpacer } from '@elastic/eui'; -import React, { useState, useCallback, useEffect } from 'react'; +import React, { useState, useCallback } from 'react'; import { useHistory } from 'react-router-dom'; import { EuiButton, @@ -17,12 +17,10 @@ import { EuiCheckableCard, EuiText, EuiTextColor, - EuiSelectableOption, } from '@elastic/eui'; import { htmlIdGenerator } from '@elastic/eui'; import { generatePath } from 'react-router-dom'; import { routerPaths } from '../../../common/router_paths'; -import { modelRepositoryManager } from '../../utils/model_repository_manager'; enum ModelSource { USER_MODEL = 'UserModel', @@ -32,58 +30,26 @@ enum ModelSource { interface Props { onCloseModal: () => void; } -interface IItem { - label: string; - checked?: 'on' | undefined; - description: string; -} + export function RegisterModelTypeModal({ onCloseModal }: Props) { - const [modelRepoSelection, setModelRepoSelection] = useState>>( - [] - ); const history = useHistory(); const [modelSource, setModelSource] = useState(ModelSource.PRE_TRAINED_MODEL); - const onChange = useCallback((modelSelection: Array>) => { - setModelRepoSelection(modelSelection); - }, []); - const handleContinue = useCallback( - (selectedOption) => { - selectedOption = onChange(modelRepoSelection); - switch (modelSource) { - case ModelSource.PRE_TRAINED_MODEL: - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import` - ); - break; - case ModelSource.USER_MODEL: - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=upload` - ); - break; - case ModelSource.EXTERNAL_MODEL: - history.push( - `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=external` - ); - break; - } - }, - [history, modelSource, modelRepoSelection, onChange] - ); + const handleContinue = useCallback(() => { + switch (modelSource) { + case ModelSource.PRE_TRAINED_MODEL: + history.push(`${generatePath(routerPaths.registerModel, { id: undefined })}/?type=import`); + break; + case ModelSource.USER_MODEL: + history.push(`${generatePath(routerPaths.registerModel, { id: undefined })}/?type=upload`); + break; + case ModelSource.EXTERNAL_MODEL: + history.push( + `${generatePath(routerPaths.registerModel, { id: undefined })}/?type=external` + ); + break; + } + }, [history, modelSource]); - useEffect(() => { - const subscribe = modelRepositoryManager.getPreTrainedModels$().subscribe((models) => { - setModelRepoSelection( - Object.keys(models).map((name) => ({ - label: name, - description: models[name].description, - checked: undefined, - })) - ); - }); - return () => { - subscribe.unsubscribe(); - }; - }, []); return ( onCloseModal()} maxWidth="1000px">