Skip to content

Commit

Permalink
Functionality complete
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilyBonar committed Jun 6, 2024
1 parent cac842d commit 6fd92ca
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 117 deletions.
100 changes: 69 additions & 31 deletions webui/react/src/components/CompareHyperparameters.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import ExperimentVisualizationFilters, {
VisualizationFilters,
} from 'pages/ExperimentDetails/ExperimentVisualization/ExperimentVisualizationFilters';
import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics';
import { ExperimentWithTrial, TrialItem } from 'types';
import { ExperimentWithTrial, FlatRun, TrialItem, XOR } from 'types';

import CompareHeatMaps from './CompareHeatMaps';
import {
Expand All @@ -23,17 +23,22 @@ import CompareParallelCoordinates from './CompareParallelCoordinates';
import CompareScatterPlots from './CompareScatterPlots';
import css from './HpParallelCoordinates.module.scss';

interface Props {
interface BaseProps {
projectId: number;
selectedExperiments: ExperimentWithTrial[];
trials: TrialItem[];
metricData: TrialMetricData;
}

type Props = XOR<
{ selectedExperiments: ExperimentWithTrial[]; trials: TrialItem[] },
{ selectedRuns: FlatRun[] }
> &
BaseProps;

export const NO_DATA_MESSAGE = 'No data available.';

const CompareHyperparameters: React.FC<Props> = ({
selectedExperiments,
selectedRuns,
trials,
projectId,
metricData,
Expand All @@ -42,9 +47,14 @@ const CompareHyperparameters: React.FC<Props> = ({

const fullHParams: string[] = useMemo(() => {
const hpParams = new Set<string>();
trials.forEach((trial) => Object.keys(trial.hyperparameters).forEach((hp) => hpParams.add(hp)));
trials?.forEach((trial) =>
Object.keys(trial.hyperparameters).forEach((hp) => hpParams.add(hp)),
);
selectedRuns?.forEach((run) =>
Object.keys(run.hyperparameters ?? {}).forEach((hp) => hpParams.add(hp)),
);
return Array.from(hpParams);
}, [trials]);
}, [selectedRuns, trials]);

const settingsConfig = useMemo(
() => settingsConfigForCompareHyperparameters(fullHParams, projectId),
Expand Down Expand Up @@ -120,11 +130,11 @@ const CompareHyperparameters: React.FC<Props> = ({
return <Spinner center spinning />;
}

if (trials.length === 0) {
if ((trials ?? selectedRuns).length === 0) {
return <Message title={NO_DATA_MESSAGE} />;
}

if (selectedExperiments.length !== 0 && metrics.length === 0) {
if ((selectedExperiments ?? selectedRuns).length !== 0 && metrics.length === 0) {
return (
<div className={css.waiting}>
<Alert
Expand All @@ -140,35 +150,63 @@ const CompareHyperparameters: React.FC<Props> = ({
<Section bodyBorder bodyScroll filters={visualizationFilters}>
<div className={css.container}>
<div className={css.chart}>
{selectedExperiments.length > 0 && (
{(selectedExperiments ?? selectedRuns).length > 0 && (
<>
<Title>Parallel Coordinates</Title>
<CompareParallelCoordinates
fullHParams={fullHParams}
metricData={metricData}
projectId={projectId}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
{selectedRuns ? (
<CompareParallelCoordinates
fullHParams={fullHParams}
metricData={metricData}
projectId={projectId}
selectedRuns={selectedRuns}
settings={settings}
/>
) : (
<CompareParallelCoordinates
fullHParams={fullHParams}
metricData={metricData}
projectId={projectId}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
)}
<Divider />
<Title>Scatter Plots</Title>
<CompareScatterPlots
fullHParams={fullHParams}
metricData={metricData}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
{selectedRuns ? (
<CompareScatterPlots
fullHParams={fullHParams}
metricData={metricData}
selectedRuns={selectedRuns}
settings={settings}
/>
) : (
<CompareScatterPlots
fullHParams={fullHParams}
metricData={metricData}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
)}
<Divider />
<Title>Heat Maps</Title>
<CompareHeatMaps
fullHParams={fullHParams}
metricData={metricData}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
{selectedRuns ? (
<CompareHeatMaps
fullHParams={fullHParams}
metricData={metricData}
selectedRuns={selectedRuns}
settings={settings}
/>
) : (
<CompareHeatMaps
fullHParams={fullHParams}
metricData={metricData}
selectedExperiments={selectedExperiments}
settings={settings}
trials={trials}
/>
)}
</>
)}
</div>
Expand Down
134 changes: 88 additions & 46 deletions webui/react/src/components/CompareParallelCoordinates.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ import { useGlasbey } from 'hooks/useGlasbey';
import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics';
import {
ExperimentWithTrial,
FlatRun,
HpTrialData,
Hyperparameter,
HyperparameterType,
Primitive,
Range,
Scale,
TrialItem,
XAxisDomain,
XOR,
} from 'types';
import { defaultNumericRange, getNumericRange, updateRange } from 'utils/chart';
import { getNumericRange } from 'utils/chart';
import { flattenObject, isPrimitive } from 'utils/data';
import { metricToKey, metricToStr } from 'utils/metric';
import { numericSorter } from 'utils/sort';
Expand All @@ -28,28 +29,37 @@ import css from './HpParallelCoordinates.module.scss';

export const COMPARE_PARALLEL_COORDINATES = 'compare-parallel-coordinates';

interface Props {
interface BaseProps {
projectId: number;
selectedExperiments: ExperimentWithTrial[];
trials: TrialItem[];
metricData: TrialMetricData;
settings: CompareHyperparametersSettings;
fullHParams: string[];
}

type Props = XOR<
{ selectedExperiments: ExperimentWithTrial[]; trials: TrialItem[] },
{ selectedRuns: FlatRun[] }
> &
BaseProps;

const CompareParallelCoordinates: React.FC<Props> = ({
selectedExperiments,
trials,
settings,
metricData,
fullHParams,
selectedRuns,
}: Props) => {
const [chartData, setChartData] = useState<HpTrialData | undefined>();
const [hermesCreatedFilters, setHermesCreatedFilters] = useState<Hermes.Filters>({});

const { metrics, data, isLoaded, setScale } = metricData;

const colorMap = useGlasbey(selectedExperiments.map((e) => e.experiment.id));
const colorMap = useGlasbey(
selectedExperiments
? selectedExperiments.map((e) => e.experiment.id)
: selectedRuns.map((r) => r.id),
);
const selectedScale = settings.scale;

useEffect(() => {
Expand All @@ -61,7 +71,7 @@ const CompareParallelCoordinates: React.FC<Props> = ({

const experimentHyperparameters = useMemo(() => {
const hpMap: Record<string, Hyperparameter> = {};
selectedExperiments.forEach((exp) => {
selectedExperiments?.forEach((exp) => {
const hps = Object.keys(exp.experiment.hyperparameters);
hps.forEach((hp) => (hpMap[hp] = exp.experiment.hyperparameters[hp]));
});
Expand Down Expand Up @@ -140,67 +150,99 @@ const CompareParallelCoordinates: React.FC<Props> = ({

useEffect(() => {
if (!selectedMetric) return;
const trialMetricsMap: Record<number, number> = {};
const trialHpMap: Record<string, Record<number, Primitive>> = {};

const trialHpdata: Record<string, Primitive[]> = {};
let trialMetricRange: Range<number> = defaultNumericRange(true);

trials?.forEach((trial) => {
const expId = trial.experimentId;
const key = metricToKey(selectedMetric);

// Choose the final metric value for each trial
const metricValue = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1];

if (!metricValue) return;
trialMetricsMap[expId] = metricValue;

trialMetricRange = updateRange<number>(trialMetricRange, metricValue);
const flatHParams = {
...trial?.hyperparameters,
...flattenObject(trial?.hyperparameters || {}),
};

Object.keys(flatHParams).forEach((hpKey) => {
const hpValue = flatHParams[hpKey];
trialHpMap[hpKey] = trialHpMap[hpKey] ?? {};
trialHpMap[hpKey][expId] = isPrimitive(hpValue)
? (hpValue as Primitive)
: JSON.stringify(hpValue);
const metricsMap: Record<number, number> = {};
const hpMap: Record<string, Record<number, Primitive>> = {};

const hpData: Record<string, Primitive[]> = {};

if (trials) {
trials.forEach((trial) => {
const expId = trial.experimentId;
const key = metricToKey(selectedMetric);

// Choose the final metric value for each trial
const metricValue = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1];

if (!metricValue) return;
metricsMap[expId] = metricValue;

const flatHParams = {
...trial?.hyperparameters,
...flattenObject(trial?.hyperparameters || {}),
};

Object.keys(flatHParams).forEach((hpKey) => {
const hpValue = flatHParams[hpKey];
hpMap[hpKey] = hpMap[hpKey] ?? {};
hpMap[hpKey][expId] = isPrimitive(hpValue)
? (hpValue as Primitive)
: JSON.stringify(hpValue);
});
});
});
} else if (selectedRuns) {
selectedRuns.forEach((run) => {
const key = metricToKey(selectedMetric);

// Choose the final metric value for each trial
const metricValue = data?.[run.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1];

const trialIds = Object.keys(trialMetricsMap)
if (!metricValue) return;
metricsMap[run.id] = metricValue;

const flatHParams = {
...run.hyperparameters,
...flattenObject(run.hyperparameters || {}),
};

Object.keys(flatHParams).forEach((hpKey) => {
const hpValue = flatHParams[hpKey];
hpMap[hpKey] = hpMap[hpKey] ?? {};
hpMap[hpKey][run.id] = isPrimitive(hpValue)
? (hpValue as Primitive)
: JSON.stringify(hpValue);
});
});
}

const trialIds = Object.keys(metricsMap)
.map((id) => parseInt(id))
.sort(numericSorter);

Object.keys(trialHpMap).forEach((hpKey) => {
trialHpdata[hpKey] = trialIds.map((trialId) => trialHpMap[hpKey][trialId]);
Object.keys(hpMap).forEach((hpKey) => {
hpData[hpKey] = trialIds.map((trialId) => hpMap[hpKey][trialId]);
});

const metricKey = metricToStr(selectedMetric);
const metricValues = trialIds.map((id) => trialMetricsMap[id]);
trialHpdata[metricKey] = metricValues;
const metricValues = trialIds.map((id) => metricsMap[id]);
hpData[metricKey] = metricValues;

const metricRange = getNumericRange(metricValues);
setChartData({
data: trialHpdata,
data: hpData,
metricRange,
metricValues,
trialIds,
});
}, [selectedExperiments, selectedMetric, fullHParams, metricData, selectedScale, trials, data]);
}, [
selectedExperiments,
selectedMetric,
fullHParams,
metricData,
selectedScale,
trials,
data,
selectedRuns,
]);

if (!isLoaded) {
return <Spinner center spinning />;
}

if (trials.length === 0) {
if ((trials ?? selectedRuns).length === 0) {
return <Message title="No data available." />;
}

if (!chartData || (selectedExperiments.length !== 0 && metrics.length === 0)) {
if (!chartData || ((selectedExperiments ?? selectedRuns).length !== 0 && metrics.length === 0)) {
return (
<div className={css.waiting}>
<Alert
Expand All @@ -214,7 +256,7 @@ const CompareParallelCoordinates: React.FC<Props> = ({

return (
<div className={css.container}>
{selectedExperiments.length > 0 && (
{(selectedExperiments ?? selectedRuns).length > 0 && (
<div className={css.chart} data-testid={COMPARE_PARALLEL_COORDINATES}>
<ParallelCoordinates
config={config}
Expand Down
Loading

0 comments on commit 6fd92ca

Please sign in to comment.