Skip to content

Commit

Permalink
feat(frontend): Support Scalar metrics in V2 compatible mode. Partial #…
Browse files Browse the repository at this point in the history
…5668 (#5811)

* feat(frontend) Support Scalar metrics in V2 compatible mode

* remove unused

* Update frontend/src/components/viewers/MetricsVisualizations.tsx

Co-authored-by: Yuan (Bob) Gong <[email protected]>

* Update frontend/src/components/viewers/MetricsVisualizations.tsx

Co-authored-by: Yuan (Bob) Gong <[email protected]>

* Update frontend/src/components/viewers/MetricsVisualizations.tsx

Co-authored-by: Yuan (Bob) Gong <[email protected]>

* Update frontend/src/components/viewers/MetricsVisualizations.tsx

Co-authored-by: Yuan (Bob) Gong <[email protected]>

* fix adjustmnet

* address comment

Co-authored-by: Yuan (Bob) Gong <[email protected]>
  • Loading branch information
zijianjoy and Bobgy authored Jun 9, 2021
1 parent 635f1ba commit 22bc9e2
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 68 deletions.
74 changes: 60 additions & 14 deletions frontend/src/components/tabs/MetricsTab.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ describe('MetricsTab common case', () => {
describe('MetricsTab with confidenceMetrics', () => {
it('shows ROC curve', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confidenceMetrics',
Expand Down Expand Up @@ -128,8 +128,8 @@ describe('MetricsTab with confidenceMetrics', () => {

it('shows error banner when confidenceMetric type is wrong', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confidenceMetrics',
Expand Down Expand Up @@ -165,8 +165,8 @@ describe('MetricsTab with confidenceMetrics', () => {

it('shows error banner when confidenceMetric is not array', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confidenceMetrics',
Expand All @@ -193,8 +193,8 @@ describe('MetricsTab with confidenceMetrics', () => {
describe('MetricsTab with confusionMatrix', () => {
it('shows confusion matrix', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confusionMatrix',
Expand Down Expand Up @@ -226,8 +226,8 @@ describe('MetricsTab with confusionMatrix', () => {

it('shows error banner when confusionMatrix type is wrong', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confusionMatrix',
Expand Down Expand Up @@ -259,8 +259,8 @@ describe('MetricsTab with confusionMatrix', () => {

it("shows error banner when confusionMatrix annotationSpecs length doesn't match rows", async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifactType = buildBasicArtifactType();
const artifact = buildBasicArtifact();
const artifactType = buildClassificationMetricsArtifactType();
const artifact = buildClassificationMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set(
'confusionMatrix',
Expand All @@ -287,19 +287,65 @@ describe('MetricsTab with confusionMatrix', () => {
});
});

describe('MetricsTab with Scalar Metrics', () => {
it('shows Scalar Metrics', async () => {
const execution = buildBasicExecution().setLastKnownState(Execution.State.COMPLETE);
const artifact = buildMetricsArtifact();
artifact.getCustomPropertiesMap().set('name', new Value().setStringValue('metrics'));
artifact.getCustomPropertiesMap().set('double', new Value().setDoubleValue(123.456));
artifact.getCustomPropertiesMap().set('int', new Value().setIntValue(123));
artifact.getCustomPropertiesMap().set(
'struct',
new Value().setStructValue(
Struct.fromJavaScript({
struct: {
field: 'a string value',
},
}),
),
);
jest.spyOn(mlmdUtils, 'getOutputArtifactsInExecution').mockResolvedValueOnce([artifact]);
jest.spyOn(mlmdUtils, 'getArtifactTypes').mockResolvedValueOnce([buildMetricsArtifactType()]);
const { getByText } = render(
<CommonTestWrapper>
<MetricsTab execution={execution}></MetricsTab>
</CommonTestWrapper>,
);
getByText('Metrics is loading.');
// We should upgrade react-scripts for capability to use libraries normally:
// https://github.com/testing-library/dom-testing-library/issues/477
await waitFor(() => getByText('Scalar Metrics: metrics'));
await waitFor(() => getByText('double'));
await waitFor(() => getByText('int'));
await waitFor(() => getByText('struct'));
});
});

function buildBasicExecution() {
const execution = new Execution();
execution.setId(123);
return execution;
}
function buildBasicArtifactType() {
function buildClassificationMetricsArtifactType() {
const artifactType = new ArtifactType();
artifactType.setName('system.ClassificationMetrics');
artifactType.setId(1);
return artifactType;
}
function buildBasicArtifact() {
function buildClassificationMetricsArtifact() {
const artifact = new Artifact();
artifact.setTypeId(1);
return artifact;
}

function buildMetricsArtifactType() {
const artifactType = new ArtifactType();
artifactType.setName('system.Metrics');
artifactType.setId(2);
return artifactType;
}
function buildMetricsArtifact() {
const artifact = new Artifact();
artifact.setTypeId(2);
return artifact;
}
167 changes: 113 additions & 54 deletions frontend/src/components/viewers/MetricsVisualizations.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { Artifact, ArtifactType } from '@kubeflow/frontend';
import { Artifact, ArtifactType, getMetadataValue } from '@kubeflow/frontend';
import HelpIcon from '@material-ui/icons/Help';
import React from 'react';
import { Array as ArrayRunType, Number, Failure, Record, String, ValidationError } from 'runtypes';
Expand All @@ -23,6 +23,7 @@ import { color, padding } from 'src/Css';
import { filterArtifactsByType } from 'src/lib/MlmdUtils';
import Banner from '../Banner';
import ConfusionMatrix, { ConfusionMatrixConfig } from './ConfusionMatrix';
import PagedTable from './PagedTable';
import ROCCurve, { ROCCurveConfig } from './ROCCurve';
import { PlotType } from './Viewer';

Expand All @@ -36,44 +37,56 @@ interface MetricsVisualizationsProps {
* and multiple visualizations associated with one artifact.
*/
export function MetricsVisualizations({ artifacts, artifactTypes }: MetricsVisualizationsProps) {
// system.ClassificationMetrics contains confusionMatrix or confidenceMetrics.
// TODO: Visualize confusionMatrix using system.ClassificationMetrics artifacts.
// https://github.com/kubeflow/pipelines/issues/5668
let classificationMetricsArtifacts = filterArtifactsByType(
'system.ClassificationMetrics',
artifactTypes,
artifacts,
);

// There can be multiple system.ClassificationMetrics artifacts per execution.
// Get confidenceMetrics and confusionMatrix from artifact.
// There can be multiple system.ClassificationMetrics or system.Metrics artifacts per execution.
// Get scalar metrics, confidenceMetrics and confusionMatrix from artifact.
// If there is no available metrics, show banner to notify users.
// Otherwise, Visualize all available metrics per artifact.
const metricsAvailableArtifacts = getMetricsAvailableArtifacts(classificationMetricsArtifacts);
const verifiedClassificationMetricsArtifacts = getVerifiedClassificationMetricsArtifacts(
artifacts,
artifactTypes,
);
const verifiedMetricsArtifacts = getVerifiedMetricsArtifacts(artifacts, artifactTypes);

if (metricsAvailableArtifacts.length === 0) {
if (
verifiedClassificationMetricsArtifacts.length === 0 &&
verifiedMetricsArtifacts.length === 0
) {
return <Banner message='There is no metrics artifact available in this step.' mode='info' />;
}

return (
<>
{metricsAvailableArtifacts.map(artifact => {
{verifiedClassificationMetricsArtifacts.map(artifact => {
return (
<React.Fragment key={artifact.getId()}>
<ConfidenceMetricsSection artifact={artifact} />
<ConfusionMatrixSection artifact={artifact} />
</React.Fragment>
);
})}
{verifiedMetricsArtifacts.map(artifact => (
<ScalarMetricsSection artifact={artifact} key={artifact.getId()} />
))}
</>
);
}

function getMetricsAvailableArtifacts(artifacts: Artifact[]): Artifact[] {
if (!artifacts) {
function getVerifiedClassificationMetricsArtifacts(
artifacts: Artifact[],
artifactTypes: ArtifactType[],
): Artifact[] {
if (!artifacts || !artifactTypes) {
return [];
}
return artifacts
// Reference: https://github.com/kubeflow/pipelines/blob/master/sdk/python/kfp/dsl/io_types.py#L124
// system.ClassificationMetrics contains confusionMatrix or confidenceMetrics.
const classificationMetricsArtifacts = filterArtifactsByType(
'system.ClassificationMetrics',
artifactTypes,
artifacts,
);

return classificationMetricsArtifacts
.map(artifact => ({
name: artifact
.getCustomPropertiesMap()
Expand All @@ -98,6 +111,25 @@ function getMetricsAvailableArtifacts(artifacts: Artifact[]): Artifact[] {
.map(x => x.artifact);
}

function getVerifiedMetricsArtifacts(
artifacts: Artifact[],
artifactTypes: ArtifactType[],
): Artifact[] {
if (!artifacts || !artifactTypes) {
return [];
}
// Reference: https://github.com/kubeflow/pipelines/blob/master/sdk/python/kfp/dsl/io_types.py#L104
// system.Metrics contains scalar metrics.
const metricsArtifacts = filterArtifactsByType('system.Metrics', artifactTypes, artifacts);

return metricsArtifacts.filter(x =>
x
.getCustomPropertiesMap()
.get('name')
?.getStringValue(),
);
}

const ROC_CURVE_DEFINITION =
'The receiver operating characteristic (ROC) curve shows the trade-off between true positive rate and false positive rate. ' +
'A lower threshold results in a higher true positive rate (and a higher false positive rate), ' +
Expand All @@ -122,7 +154,7 @@ function ConfidenceMetricsSection({ artifact }: ConfidenceMetricsSectionProps) {
?.getStructValue()
?.toJavaScript();
if (confidenceMetrics === undefined) {
return <></>;
return null;
}

const { error } = validateConfidenceMetrics((confidenceMetrics as any).list);
Expand All @@ -132,23 +164,19 @@ function ConfidenceMetricsSection({ artifact }: ConfidenceMetricsSectionProps) {
return <Banner message={errorMsg} mode='error' additionalInfo={error} />;
}
return (
<>
{
<div className={padding(40, 'lrt')}>
<div className={padding(40, 'b')}>
<h3>
{'ROC Curve: ' + name}{' '}
<IconWithTooltip
Icon={HelpIcon}
iconColor={color.weak}
tooltip={ROC_CURVE_DEFINITION}
></IconWithTooltip>
</h3>
</div>
<ROCCurve configs={buildRocCurveConfig((confidenceMetrics as any).list)} />
</div>
}
</>
<div className={padding(40, 'lrt')}>
<div className={padding(40, 'b')}>
<h3>
{'ROC Curve: ' + name}{' '}
<IconWithTooltip
Icon={HelpIcon}
iconColor={color.weak}
tooltip={ROC_CURVE_DEFINITION}
></IconWithTooltip>
</h3>
</div>
<ROCCurve configs={buildRocCurveConfig((confidenceMetrics as any).list)} />
</div>
);
}

Expand Down Expand Up @@ -213,7 +241,7 @@ function ConfusionMatrixSection({ artifact }: ConfusionMatrixProps) {
?.getStructValue()
?.toJavaScript();
if (confusionMatrix === undefined) {
return <></>;
return null;
}

const { error } = validateConfusionMatrix(confusionMatrix.struct as any);
Expand All @@ -223,23 +251,19 @@ function ConfusionMatrixSection({ artifact }: ConfusionMatrixProps) {
return <Banner message={errorMsg} mode='error' additionalInfo={error} />;
}
return (
<>
{
<div className={padding(40, 'lrt')}>
<div className={padding(40, 'b')}>
<h3>
{'Confusion Matrix: ' + name}{' '}
<IconWithTooltip
Icon={HelpIcon}
iconColor={color.weak}
tooltip={CONFUSION_MATRIX_DEFINITION}
></IconWithTooltip>
</h3>
</div>
<ConfusionMatrix configs={buildConfusionMatrixConfig(confusionMatrix.struct as any)} />
</div>
}
</>
<div className={padding(40)}>
<div className={padding(40, 'b')}>
<h3>
{'Confusion Matrix: ' + name}{' '}
<IconWithTooltip
Icon={HelpIcon}
iconColor={color.weak}
tooltip={CONFUSION_MATRIX_DEFINITION}
></IconWithTooltip>
</h3>
</div>
<ConfusionMatrix configs={buildConfusionMatrixConfig(confusionMatrix.struct as any)} />
</div>
);
}

Expand Down Expand Up @@ -289,3 +313,38 @@ function buildConfusionMatrixConfig(
},
];
}

interface ScalarMetricsSectionProps {
artifact: Artifact;
}
function ScalarMetricsSection({ artifact }: ScalarMetricsSectionProps) {
const customProperties = artifact.getCustomPropertiesMap();
const name = customProperties.get('name')?.getStringValue();
const data = customProperties
.getEntryList()
.map(([key]) => ({
key,
value: JSON.stringify(getMetadataValue(customProperties.get(key))),
}))
.filter(metric => metric.key !== 'name');

if (data.length === 0) {
return null;
}
return (
<div className={padding(40, 'lrt')}>
<div className={padding(40, 'b')}>
<h3>{'Scalar Metrics: ' + name}</h3>
</div>
<PagedTable
configs={[
{
data: data.map(d => [d.key, d.value]),
labels: ['name', 'value'],
type: PlotType.TABLE,
},
]}
/>
</div>
);
}

0 comments on commit 22bc9e2

Please sign in to comment.