Skip to content

Commit

Permalink
[ML] Fix model test flyout reload (#144318) (#144388)
Browse files Browse the repository at this point in the history
(cherry picked from commit c6a0058)

Co-authored-by: James Gowdy <[email protected]>
  • Loading branch information
kibanamachine and jgowdyelastic authored Nov 2, 2022
1 parent aa4540d commit ed16c6c
Showing 1 changed file with 27 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC } from 'react';
import React, { FC, useMemo } from 'react';

import { NerInference } from './models/ner';
import { QuestionAnsweringInference } from './models/question_answering';
Expand All @@ -28,53 +28,41 @@ import { useMlApiContext } from '../../../contexts/kibana';
import { InferenceInputForm } from './models/inference_input_form';

interface Props {
model: estypes.MlTrainedModelConfig | null;
model: estypes.MlTrainedModelConfig;
}

export const SelectedModel: FC<Props> = ({ model }) => {
const { trainedModels } = useMlApiContext();

if (model === null) {
return null;
}

if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER) {
const inferrer = new NerInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}

if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION) {
const inferrer = new TextClassificationInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
const inferrer = useMemo(() => {
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
const taskType = Object.keys(model.inference_config)[0];

if (
Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION
) {
const inferrer = new ZeroShotClassificationInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}

if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING) {
const inferrer = new TextEmbeddingInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
return new NerInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
return new TextClassificationInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
return new ZeroShotClassificationInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
return new TextEmbeddingInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
return new FillMaskInference(trainedModels, model);
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
return new QuestionAnsweringInference(trainedModels, model);

if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.FILL_MASK) {
const inferrer = new FillMaskInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
return new LangIdentInference(trainedModels, model);
}
}, [model, trainedModels]);

if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING) {
const inferrer = new QuestionAnsweringInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
}
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
const inferrer = new LangIdentInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
if (inferrer === undefined) {
return null;
}

return null;
return <InferenceInputForm inferrer={inferrer} />;
};

0 comments on commit ed16c6c

Please sign in to comment.