Skip to content

Commit

Permalink
fix: comparison view parallel coordinates chart shouldn't break when …
Browse files Browse the repository at this point in the history
…selecting rows [ET-261] (#9584)
  • Loading branch information
emily-roses authored Jul 3, 2024
1 parent d159f14 commit a6a79b8
Show file tree
Hide file tree
Showing 12 changed files with 368 additions and 94 deletions.
212 changes: 207 additions & 5 deletions webui/react/src/components/CompareHyperparameters.test.mock.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,66 @@
import React from 'react';

import { useGlasbey } from 'hooks/useGlasbey';
import { RunMetricData } from 'hooks/useMetrics';
import { Scale } from 'types';
import { generateTestRunData } from 'utils/tests/generateTestData';

import CompareHyperparameters from './CompareHyperparameters';
export const METRIC_DATA: RunMetricData = {
data: {
3400: {
1: {
'{"group":"training","name":"loss"}': {
data: {
Batches: [[2, 0.5823304653167725]],
Epoch: [],
Time: [[1656260140.728, 0.5823304653167725]],
},
name: 'training.loss',
},
'{"group":"validation","name":"accuracy"}': {
data: {
Batches: [[2, 0.8522093949044586]],
Epoch: [],
Time: [[1656260146.436, 0.8522093949044586]],
},
name: 'validation.accuracy',
},
'{"group":"validation","name":"validation_loss"}': {
data: {
Batches: [[2, 0.49773818169050155]],
Epoch: [],
Time: [[1656260146.436, 0.49773818169050155]],
},
name: 'validation.validation_loss',
},
},
2: {
'{"group":"training","name":"loss"}': {
data: {
Batches: [[2, 0.5823304653167725]],
Epoch: [],
Time: [[1656260140.728, 0.5823304653167725]],
},
name: 'training.loss',
},
'{"group":"validation","name":"accuracy"}': {
data: {
Batches: [[2, 0.8522093949044586]],
Epoch: [],
Time: [[1656260146.436, 0.8522093949044586]],
},
name: 'validation.accuracy',
},
'{"group":"validation","name":"validation_loss"}': {
data: {
Batches: [[2, 0.49773818169050155]],
Epoch: [],
Time: [[1656260146.436, 0.49773818169050155]],
},
name: 'validation.validation_loss',
},
},
3: {
'{"group":"training","name":"loss"}': {
data: {
Batches: [[2, 0.5823304653167725]],
Expand Down Expand Up @@ -346,7 +399,145 @@ export const TRIALS = [
n_filters1: 54,
n_filters2: 70,
},
id: 3400,
id: 1,
latestValidationMetric: {
endTime: '2022-06-26T16:15:46.436495Z',
metrics: {
accuracy: 0.8522093949044586,
validation_loss: 0.49773818169050155,
},
totalBatches: 2,
},
searcherMetricsVal: 1,
startTime: '2022-06-26T16:08:36.678225Z',
state: 'COMPLETED',
summaryMetrics: {
avgMetrics: {
loss: {
count: 1,
last: 0.5823304653167725,
max: 0.582330465316772,
mean: 0.582330465316772,
min: 0.582330465316772,
sum: 0.582330465316772,
type: 'number',
},
},
validationMetrics: {
accuracy: {
count: 1,
last: 0.8522093949044586,
max: 0.852209394904459,
mean: 0.852209394904459,
min: 0.852209394904459,
sum: 0.852209394904459,
type: 'number',
},
validation_loss: {
count: 1,
last: 0.49773818169050155,
max: 0.497738181690502,
mean: 0.497738181690502,
min: 0.497738181690502,
sum: 0.497738181690502,
type: 'number',
},
},
},
totalBatchesProcessed: 100,
totalCheckpointSize: 83008221,
},
{
autoRestarts: 0,
bestAvailableCheckpoint: null,
bestValidationMetric: {
endTime: '2023-04-20T16:20:22.902226Z',
metrics: {
loss: 1,
},
totalBatches: 1,
},
checkpointCount: 1,
endTime: '2022-06-26T16:16:04.171606Z',
experimentId: 1156,
hyperparameters: {
dropout1: 0.532803505916605,
dropout2: 0.39400711778394015,
global_batch_size: 64,
learning_rate: 0.06716139157036664,
n_filters1: 54,
n_filters2: 70,
},
id: 2,
latestValidationMetric: {
endTime: '2022-06-26T16:15:46.436495Z',
metrics: {
accuracy: 0.8522093949044586,
validation_loss: 0.49773818169050155,
},
totalBatches: 2,
},
searcherMetricsVal: 1,
startTime: '2022-06-26T16:08:36.678225Z',
state: 'COMPLETED',
summaryMetrics: {
avgMetrics: {
loss: {
count: 1,
last: 0.5823304653167725,
max: 0.582330465316772,
mean: 0.582330465316772,
min: 0.582330465316772,
sum: 0.582330465316772,
type: 'number',
},
},
validationMetrics: {
accuracy: {
count: 1,
last: 0.8522093949044586,
max: 0.852209394904459,
mean: 0.852209394904459,
min: 0.852209394904459,
sum: 0.852209394904459,
type: 'number',
},
validation_loss: {
count: 1,
last: 0.49773818169050155,
max: 0.497738181690502,
mean: 0.497738181690502,
min: 0.497738181690502,
sum: 0.497738181690502,
type: 'number',
},
},
},
totalBatchesProcessed: 100,
totalCheckpointSize: 83008221,
},
{
autoRestarts: 0,
bestAvailableCheckpoint: null,
bestValidationMetric: {
endTime: '2023-04-20T16:20:22.902226Z',
metrics: {
loss: 1,
},
totalBatches: 1,
},
checkpointCount: 1,
endTime: '2022-06-26T16:16:04.171606Z',
experimentId: 1156,
hyperparameters: {
dropout1: 0.532803505916605,
dropout2: 0.39400711778394015,
global_batch_size: 64,
learning_rate: 0.06716139157036664,
n_filters1: 54,
n_filters2: 70,
},
id: 3,
latestValidationMetric: {
endTime: '2022-06-26T16:15:46.436495Z',
metrics: {
Expand Down Expand Up @@ -396,17 +587,25 @@ export const TRIALS = [
},
];

export const SELECTED_RUNS = [generateTestRunData(), generateTestRunData(), generateTestRunData()];
export const SELECTED_RUNS = [
generateTestRunData(1),
generateTestRunData(2),
generateTestRunData(3),
];

interface Props {
empty?: boolean;
comparableMetrics?: boolean;
}
export const CompareTrialHyperparametersWithMocks: React.FC<Props> = ({
empty,
comparableMetrics = true,
}: Props): JSX.Element => {
const colorMap = useGlasbey(SELECTED_EXPERIMENTS.map((exp) => exp.experiment.id));
return (
<CompareHyperparameters
metricData={METRIC_DATA}
colorMap={colorMap}
metricData={comparableMetrics ? METRIC_DATA : { ...METRIC_DATA, data: {} }}
projectId={1}
// @ts-expect-error Mock data does not need type checking
selectedExperiments={empty ? [] : SELECTED_EXPERIMENTS}
Expand All @@ -418,10 +617,13 @@ export const CompareTrialHyperparametersWithMocks: React.FC<Props> = ({

export const CompareRunHyperparametersWithMocks: React.FC<Props> = ({
empty,
comparableMetrics = true,
}: Props): JSX.Element => {
const colorMap = useGlasbey(SELECTED_RUNS.map((run) => run.id));
return (
<CompareHyperparameters
metricData={METRIC_DATA}
colorMap={colorMap}
metricData={comparableMetrics ? METRIC_DATA : { ...METRIC_DATA, data: {} }}
projectId={1}
selectedRuns={empty ? [] : SELECTED_RUNS}
/>
Expand Down
22 changes: 19 additions & 3 deletions webui/react/src/components/CompareHyperparameters.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,26 @@ vi.mock('hooks/useSettings', async (importOriginal) => {
};
});

const setup = (type: 'trials' | 'runs', empty?: boolean) => {
const setup = (
type: 'trials' | 'runs',
empty: boolean = false,
comparableMetrics: boolean = true,
) => {
render(
<BrowserRouter>
<UIProvider theme={DefaultTheme.Light}>
<ThemeProvider>
<SettingsProvider>
{type === 'trials' ? (
<CompareTrialHyperparametersWithMocks empty={empty} />
<CompareTrialHyperparametersWithMocks
comparableMetrics={comparableMetrics}
empty={empty}
/>
) : (
<CompareRunHyperparametersWithMocks empty={empty} />
<CompareRunHyperparametersWithMocks
comparableMetrics={comparableMetrics}
empty={empty}
/>
)}
</SettingsProvider>
</ThemeProvider>
Expand All @@ -60,6 +70,12 @@ describe('CompareHyperparameters component', () => {
setup(type);
expect(screen.getByTestId(COMPARE_PARALLEL_COORDINATES)).toBeInTheDocument();
});
it('renders Parallel Coordinates error when metrics are incompatable', () => {
setup(type, false, false);
expect(
screen.getByText('Records are not comparable using current parameters.'),
).toBeInTheDocument();
});
it('renders Scatter Plots', () => {
setup(type);
expect(screen.getByTestId(COMPARE_SCATTER_PLOTS)).toBeInTheDocument();
Expand Down
6 changes: 6 additions & 0 deletions webui/react/src/components/CompareHyperparameters.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { Title } from 'hew/Typography';
import React, { useCallback, useEffect, useMemo } from 'react';

import Section from 'components/Section';
import { MapOfIdsToColors } from 'hooks/useGlasbey';
import { RunMetricData } from 'hooks/useMetrics';
import { useSettings } from 'hooks/useSettings';
import { ExperimentVisualizationType } from 'pages/ExperimentDetails/ExperimentVisualization';
Expand All @@ -24,6 +25,7 @@ import CompareScatterPlots from './CompareScatterPlots';
import css from './HpParallelCoordinates.module.scss';

interface BaseProps {
colorMap: MapOfIdsToColors;
projectId: number;
metricData: RunMetricData;
}
Expand All @@ -39,6 +41,7 @@ export const NO_DATA_MESSAGE = 'No data available.';
const CompareHyperparameters: React.FC<Props> = ({
selectedExperiments,
selectedRuns,
colorMap,
trials,
projectId,
metricData,
Expand Down Expand Up @@ -93,6 +96,7 @@ const CompareHyperparameters: React.FC<Props> = ({
}, [resetSettings]);

useEffect(() => {
if (metrics.length === 0) return;
const activeMetricFound = metrics.find(
(metric) =>
metric.name === settings?.metric?.name && metric.group === settings?.metric?.group,
Expand Down Expand Up @@ -155,6 +159,7 @@ const CompareHyperparameters: React.FC<Props> = ({
<Title>Parallel Coordinates</Title>
{selectedRuns ? (
<CompareParallelCoordinates
colorMap={colorMap}
fullHParams={fullHParams}
metricData={metricData}
projectId={projectId}
Expand All @@ -163,6 +168,7 @@ const CompareHyperparameters: React.FC<Props> = ({
/>
) : (
<CompareParallelCoordinates
colorMap={colorMap}
fullHParams={fullHParams}
metricData={metricData}
projectId={projectId}
Expand Down
Loading

0 comments on commit a6a79b8

Please sign in to comment.