Skip to content

Commit

Permalink
feat: add heatmap to runs table [ET-230] (#9429)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilyBonar authored May 29, 2024
1 parent 0599d0e commit da2f943
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 80 deletions.
4 changes: 4 additions & 0 deletions webui/react/src/pages/FlatRuns/FlatRuns.settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ export const FlatRunsSettings = t.intersection([
columnWidths: t.record(t.string, t.number),
compare: t.boolean,
filterset: t.string, // save FilterFormSet as string
heatmapOn: t.boolean,
heatmapSkipped: t.array(t.string),
pageLimit: t.number,
pinnedColumnsCount: t.number,
selection: SelectionType,
Expand All @@ -26,6 +28,8 @@ export const defaultFlatRunsSettings: Required<FlatRunsSettings> = {
columnWidths: defaultColumnWidths,
compare: false,
filterset: JSON.stringify(INIT_FORMSET),
heatmapOn: false,
heatmapSkipped: [],
pageLimit: 20,
pinnedColumnsCount: 3,
selection: DEFAULT_SELECTION,
Expand Down
239 changes: 183 additions & 56 deletions webui/react/src/pages/FlatRuns/FlatRuns.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { CompactSelection, GridSelection } from '@glideapps/glide-data-grid';
import { isLeft } from 'fp-ts/lib/Either';
import Button from 'hew/Button';
import Column from 'hew/Column';
import {
ColumnDef,
Expand Down Expand Up @@ -63,7 +64,7 @@ import {
SelectionType as SelectionState,
} from 'pages/F_ExpList/F_ExperimentList.settings';
import { paths } from 'routes/utils';
import { getProjectColumns, searchRuns } from 'services/api';
import { getProjectColumns, getProjectNumericMetricsRange, searchRuns } from 'services/api';
import { V1ColumnType, V1LocationType } from 'services/api-ts-sdk';
import userStore from 'stores/users';
import userSettings from 'stores/userSettings';
Expand All @@ -78,6 +79,7 @@ import {
getColumnDefs,
RunColumn,
runColumns,
searcherMetricsValColumn,
} from './columns';
import css from './FlatRuns.module.scss';
import {
Expand Down Expand Up @@ -169,6 +171,15 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
isDarkMode,
} = useUI();

const projectHeatmap = useAsync(async () => {
try {
return await getProjectNumericMetricsRange({ id: project.id });
} catch (e) {
handleError(e, { publicSubject: 'Unable to fetch project heatmap' });
return NotLoaded;
}
}, [project.id]);

const projectColumns = useAsync(async () => {
try {
const columns = await getProjectColumns({ id: project.id });
Expand Down Expand Up @@ -291,16 +302,39 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
break;
}
switch (currentColumn.type) {
case V1ColumnType.NUMBER:
columnDefs[currentColumn.column] = defaultNumberColumn(
currentColumn.column,
currentColumn.displayName || currentColumn.column,
settings.columnWidths[currentColumn.column] ??
defaultColumnWidths[currentColumn.column as RunColumn] ??
MIN_COLUMN_WIDTH,
dataPath,
);
case V1ColumnType.NUMBER: {
const heatmap = projectHeatmap
.getOrElse([])
.find((h) => h.metricsName === currentColumn.column);
if (
heatmap &&
settings.heatmapOn &&
!settings.heatmapSkipped.includes(currentColumn.column)
) {
columnDefs[currentColumn.column] = defaultNumberColumn(
currentColumn.column,
currentColumn.displayName || currentColumn.column,
settings.columnWidths[currentColumn.column] ??
defaultColumnWidths[currentColumn.column as RunColumn] ??
MIN_COLUMN_WIDTH,
dataPath,
{
max: heatmap.max,
min: heatmap.min,
},
);
} else {
columnDefs[currentColumn.column] = defaultNumberColumn(
currentColumn.column,
currentColumn.displayName || currentColumn.column,
settings.columnWidths[currentColumn.column] ??
defaultColumnWidths[currentColumn.column as RunColumn] ??
MIN_COLUMN_WIDTH,
dataPath,
);
}
break;
}
case V1ColumnType.DATE:
columnDefs[currentColumn.column] = defaultDateColumn(
currentColumn.column,
Expand All @@ -323,6 +357,21 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
dataPath,
);
}
if (currentColumn.column === 'searcherMetricsVal') {
const heatmap = projectHeatmap
.getOrElse([])
.find((h) => h.metricsName === currentColumn.column);

columnDefs[currentColumn.column] = searcherMetricsValColumn(
settings.columnWidths[currentColumn.column],
heatmap && settings.heatmapOn && !settings.heatmapSkipped.includes(currentColumn.column)
? {
max: heatmap.max,
min: heatmap.min,
}
: undefined,
);
}
return columnDefs[currentColumn.column];
})
.flatMap((col) => (col ? [col] : []));
Expand All @@ -332,9 +381,12 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
columnsIfLoaded,
isDarkMode,
projectColumns,
projectHeatmap,
selection.rows,
settings.columnWidths,
settings.compare,
settings.heatmapOn,
settings.heatmapSkipped,
settings.pinnedColumnsCount,
users,
]);
Expand All @@ -346,6 +398,31 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
[updateGlobalSettings],
);

const handleHeatmapToggle = useCallback(
(heatmapOn: boolean) => updateSettings({ heatmapOn: !heatmapOn }),
[updateSettings],
);

const handleHeatmapSelection = useCallback(
(selection: string[]) => updateSettings({ heatmapSkipped: selection }),
[updateSettings],
);

const heatmapBtnVisible = useMemo(() => {
const visibleColumns = settings.columns.slice(
0,
settings.compare ? settings.pinnedColumnsCount : undefined,
);
return Loadable.getOrElse([], projectColumns).some(
(column) =>
visibleColumns.includes(column.column) &&
(column.column === 'searcherMetricsVal' ||
(column.type === V1ColumnType.NUMBER &&
(column.location === V1LocationType.VALIDATIONS ||
column.location === V1LocationType.TRAINING))),
);
}, [settings.columns, projectColumns, settings.pinnedColumnsCount, settings.compare]);

const onPageChange = useCallback(
(cPage: number, cPageSize: number) => {
updateSettings({ pageLimit: cPageSize });
Expand Down Expand Up @@ -732,33 +809,60 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
},
},
);
}

if (filterCount > 0) {
items.push({
icon: <Icon decorative name="filter" />,
key: 'filter-clear',
label: `Clear ${pluralizer(filterCount, 'Filter')} (${filterCount})`,
onClick: () => {
setTimeout(clearFilterForColumn, 5);
},
});
}
if (filterCount > 0) {
items.push({
icon: <Icon decorative name="filter" />,
key: 'filter-clear',
label: `Clear ${pluralizer(filterCount, 'Filter')} (${filterCount})`,
onClick: () => {
setTimeout(clearFilterForColumn, 5);
},
});
}
if (
settings.heatmapOn &&
(column.column === 'searcherMetricsVal' ||
(column.type === V1ColumnType.NUMBER &&
(column.location === V1LocationType.VALIDATIONS ||
column.location === V1LocationType.TRAINING)))
) {
items.push(
{ type: 'divider' as const },
{
icon: <Icon decorative name="heatmap" />,
key: 'heatmap',
label: !settings.heatmapSkipped.includes(column.column)
? 'Cancel heatmap'
: 'Apply heatmap',
onClick: () =>
handleHeatmapSelection?.(
settings.heatmapSkipped.includes(column.column)
? settings.heatmapSkipped.filter((p) => p !== column.column)
: [...settings.heatmapSkipped, column.column],
),
},
);
}
return items;
},
[
projectColumns,
settings.pinnedColumnsCount,
settings.selection,
settings.pageLimit,
settings.heatmapOn,
settings.heatmapSkipped,
isMobile,
handleSelectionChange,
columnsIfLoaded,
handleColumnsOrderChange,
handleSelectionChange,
handleSortChange,
isMobile,
loadableFormset,
handleIsOpenFilterChange,
projectColumns,
settings.pinnedColumnsCount,
sorts,
settings.pageLimit,
settings.selection,
handleSortChange,
handleHeatmapSelection,
],
);

Expand All @@ -772,35 +876,58 @@ const FlatRuns: React.FC<Props> = ({ project }) => {
return (
<div className={css.content} ref={contentRef}>
<Row>
<TableFilter
bannedFilterColumns={BANNED_FILTER_COLUMNS}
formStore={formStore}
isMobile={isMobile}
isOpenFilter={isOpenFilter}
loadableColumns={projectColumns}
onIsOpenFilterChange={handleIsOpenFilterChange}
/>
<MultiSortMenu
columns={projectColumns}
isMobile={isMobile}
sorts={sorts}
onChange={handleSortChange}
/>
<ColumnPickerMenu
defaultVisibleColumns={defaultRunColumns}
initialVisibleColumns={columnsIfLoaded}
isMobile={isMobile}
pinnedColumnsCount={settings.pinnedColumnsCount}
projectColumns={projectColumns}
projectId={project.id}
tabs={[
V1LocationType.EXPERIMENT,
[V1LocationType.VALIDATIONS, V1LocationType.TRAINING, V1LocationType.CUSTOMMETRIC],
V1LocationType.HYPERPARAMETERS,
]}
onVisibleColumnChange={handleColumnsOrderChange}
/>
<OptionsMenu rowHeight={globalSettings.rowHeight} onRowHeightChange={onRowHeightChange} />
<Column>
<Row>
<TableFilter
bannedFilterColumns={BANNED_FILTER_COLUMNS}
formStore={formStore}
isMobile={isMobile}
isOpenFilter={isOpenFilter}
loadableColumns={projectColumns}
onIsOpenFilterChange={handleIsOpenFilterChange}
/>
<MultiSortMenu
columns={projectColumns}
isMobile={isMobile}
sorts={sorts}
onChange={handleSortChange}
/>
<ColumnPickerMenu
defaultVisibleColumns={defaultRunColumns}
initialVisibleColumns={columnsIfLoaded}
isMobile={isMobile}
pinnedColumnsCount={settings.pinnedColumnsCount}
projectColumns={projectColumns}
projectId={project.id}
tabs={[
V1LocationType.EXPERIMENT,
[V1LocationType.VALIDATIONS, V1LocationType.TRAINING, V1LocationType.CUSTOMMETRIC],
V1LocationType.HYPERPARAMETERS,
]}
onHeatmapSelectionRemove={(id) => {
const newSelection = settings.heatmapSkipped.filter((s) => s !== id);
handleHeatmapSelection(newSelection);
}}
onVisibleColumnChange={handleColumnsOrderChange}
/>
<OptionsMenu
rowHeight={globalSettings.rowHeight}
onRowHeightChange={onRowHeightChange}
/>
</Row>
</Column>
<Column align="right">
<Row>
{heatmapBtnVisible && (
<Button
icon={<Icon name="heatmap" title="heatmap" />}
tooltip="Toggle Metric Heatmap"
type={settings.heatmapOn ? 'primary' : 'default'}
onClick={() => handleHeatmapToggle(settings.heatmapOn ?? false)}
/>
)}
</Row>
</Column>
</Row>
{!isLoading && total.isLoaded && total.data === 0 ? (
numFilters === 0 ? (
Expand Down
Loading

0 comments on commit da2f943

Please sign in to comment.