Skip to content

Commit

Permalink
Calibrate ensemble weights charts (#5681)
Browse files Browse the repository at this point in the history
  • Loading branch information
jryu01 authored Dec 2, 2024
1 parent f6d9d6d commit d7a61f5
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 45 deletions.
11 changes: 9 additions & 2 deletions packages/client/hmi-client/src/components/widgets/VegaChart.vue
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import embed, { Config, Result, VisualizationSpec } from 'vega-embed';
import Button from 'primevue/button';
import Dialog from 'primevue/dialog';
import { countDigits, fixPrecisionError } from '@/utils/number';
import { ref, watch, toRaw, isRef, isReactive, isProxy, computed, h, render } from 'vue';
import { ref, watch, toRaw, isRef, isReactive, isProxy, computed, h, render, onUnmounted } from 'vue';
const NUMBER_FORMAT = '.3~s';
Expand Down Expand Up @@ -119,6 +119,7 @@ const onExpand = async () => {
if (typeof props.expandable === 'function') {
spec = props.expandable(spec);
}
vegaVisualizationExpanded.value?.finalize(); // dispose previous visualization before creating a new one
vegaVisualizationExpanded.value = await createVegaVisualization(vegaContainerLg.value, spec, props.config, {
actions: props.areEmbedActionsVisible,
expandable: false
Expand Down Expand Up @@ -219,8 +220,8 @@ async function createVegaVisualization(
watch(
[vegaContainer, () => props.visualizationSpec],
async ([, newSpec], [, oldSpec]) => {
if (_.isEmpty(newSpec)) return;
const isEqual = _.isEqual(newSpec, oldSpec);
if (isEqual && vegaVisualization.value !== undefined) return;
const spec = deepToRaw(props.visualizationSpec);
Expand All @@ -246,6 +247,7 @@ watch(
} else {
// console.log('render interactive');
if (!vegaContainer.value) return;
vegaVisualization.value?.finalize(); // dispose previous visualization before creating a new one
vegaVisualization.value = await createVegaVisualization(vegaContainer.value, spec, props.config, {
actions: props.areEmbedActionsVisible,
expandable: !!props.expandable
Expand All @@ -256,6 +258,11 @@ watch(
{ immediate: true }
);
onUnmounted(() => {
vegaVisualization.value?.finalize();
vegaVisualizationExpanded.value?.finalize();
});
defineExpose({
view,
expandedView
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ export interface CalibrateEnsembleWeights {
[key: string]: number;
}

export interface CalibrateEnsembleCiemssOperationState extends BaseState {
export interface CalibrateEnsembleCiemssOperationOutputSettingsState {
showLossChart: boolean;
chartSettings: ChartSetting[] | null;
showModelWeightsCharts: boolean;
}

export interface CalibrateEnsembleCiemssOperationState
extends BaseState,
CalibrateEnsembleCiemssOperationOutputSettingsState {
ensembleMapping: CalibrateEnsembleMappingRow[];
configurationWeights: CalibrateEnsembleWeights;
timestampColName: string;
Expand Down Expand Up @@ -73,6 +80,8 @@ export const CalibrateEnsembleCiemssOperation: Operation = {
initState: () => {
const init: CalibrateEnsembleCiemssOperationState = {
chartSettings: null,
showLossChart: true,
showModelWeightsCharts: true,
ensembleMapping: [],
configurationWeights: {},
timestampColName: '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@
</section>
<template #preview>
<tera-drilldown-section>
<section class="pb-3">
<Accordion multiple :active-index="[0, 1]" class="px-2">
<section class="pb-3 px-2">
<div class="mx-2" ref="chartWidthDiv"></div>
<Accordion multiple :active-index="[0, 1, 2]">
<!-- <AccordionTab header="Summary">
</AccordionTab> -->
<AccordionTab header="Loss">
<div ref="chartWidthDiv"></div>
<AccordionTab v-if="node.state.showLossChart" header="Loss">
<vega-chart
v-if="!_.isEmpty(lossValues)"
expandable
Expand All @@ -206,7 +206,6 @@
</AccordionTab>
<template v-if="!isRunInProgress">
<AccordionTab header="Ensemble variables over time">
<div ref="outputPanel"></div>
<div class="flex flex-row" v-for="setting of selectedEnsembleVariableSettings" :key="setting.id">
<vega-chart
v-for="(spec, index) of ensembleVariableCharts[setting.id]"
Expand All @@ -217,6 +216,17 @@
/>
</div>
</AccordionTab>
<AccordionTab v-if="node.state.showModelWeightsCharts" header="Model weights">
<div class="flex flex-row">
<vega-chart
v-for="(spec, index) of weightsDistributionCharts"
:key="index"
expandable
:are-embed-actions-visible="true"
:visualization-spec="spec"
/>
</div>
</AccordionTab>
</template>
</Accordion>
<tera-progress-spinner v-if="isRunInProgress" :font-size="2" is-centered style="height: 100%">
Expand Down Expand Up @@ -248,6 +258,13 @@
</template>
<template #content>
<div class="output-settings-panel">
<h5>Loss</h5>
<tera-checkbox
label="Show loss chart"
:model-value="Boolean(node.state.showLossChart)"
@update:model-value="emit('update-state', { ...node.state, showLossChart: $event })"
/>
<Divider />
<tera-chart-settings
:title="'Ensemble variables over time'"
:settings="chartSettings"
Expand All @@ -260,6 +277,13 @@
@toggle-ensemble-variable-setting-option="updateEnsembleVariableSettingOption"
/>
<Divider />
<h5>Model Weights</h5>
<tera-checkbox
label="Show distributions in charts"
:model-value="Boolean(node.state.showModelWeightsCharts)"
@update:model-value="emit('update-state', { ...node.state, showModelWeightsCharts: $event })"
/>
<Divider />
</div>
</template>
</tera-slider-panel>
Expand Down Expand Up @@ -297,6 +321,7 @@ import TeraPyciemssCancelButton from '@/components/pyciemss/tera-pyciemss-cancel
import TeraSliderPanel from '@/components/widgets/tera-slider-panel.vue';
import TeraChartSettings from '@/components/widgets/tera-chart-settings.vue';
import TeraChartSettingsPanel from '@/components/widgets/tera-chart-settings-panel.vue';
import TeraCheckbox from '@/components/widgets/tera-checkbox.vue';
import TeraInputText from '@/components/widgets/tera-input-text.vue';
import TeraSignalBars from '@/components/widgets/tera-signal-bars.vue';
import TeraTimestepCalendar from '@/components/widgets/tera-timestep-calendar.vue';
Expand Down Expand Up @@ -554,18 +579,20 @@ const {
updateEnsembleVariableSettingOption
} = useChartSettings(props, emit);
const { generateAnnotation, getChartAnnotationsByChartId, useEnsembleVariableCharts } = useCharts(
props.node.id,
null,
allModelConfigurations,
computed(() => buildChartData(outputData.value, selectedOutputMapping.value)),
chartSize,
null,
selectedOutputMapping
);
const { generateAnnotation, getChartAnnotationsByChartId, useEnsembleVariableCharts, useWeightsDistributionCharts } =
useCharts(
props.node.id,
null,
allModelConfigurations,
computed(() => buildChartData(outputData.value, selectedOutputMapping.value)),
chartSize,
null,
selectedOutputMapping
);
const ensembleVariables = computed(() => getSelectedOutputEnsembleMapping(props.node, false).map((d) => d.newName));
const ensembleVariableCharts = useEnsembleVariableCharts(selectedEnsembleVariableSettings, groundTruthData);
const weightsDistributionCharts = useWeightsDistributionCharts();
// --------------------------------------------------------
watch(
Expand Down
97 changes: 71 additions & 26 deletions packages/client/hmi-client/src/composables/useCharts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
ForecastChartOptions
} from '@/services/charts';
import { flattenInterventionData } from '@/services/intervention-policy';
import { DataArray, extractModelConfigIds } from '@/services/models/simulation-service';
import { DataArray, extractModelConfigIdsInOrder, extractModelConfigIds } from '@/services/models/simulation-service';
import { ChartSetting, ChartSettingEnsembleVariable, ChartSettingType } from '@/types/common';
import { Intervention, Model, ModelConfiguration } from '@/types/Types';
import { displayNumber } from '@/utils/number';
Expand Down Expand Up @@ -41,20 +41,14 @@ type VariableMappings = CalibrateMap[] | EnsembleVariableMappings;
const BASE_GREY = '#AAB3C6';
const PRIMARY_COLOR = CATEGORICAL_SCHEME[0];

// Get the model configuration id to variable name mappings for the given ensemble variable
const getModelConfigMappings = (mapping: EnsembleVariableMappings, ensembleVariableName: string) => {
const modelConfigMappings = mapping.find((d) => d.newName === ensembleVariableName)?.modelConfigurationMappings;
return modelConfigMappings ?? {};
};

// Get the model variable name for the corresponding model configuration and the ensemble variable name from the mapping
const getModelConfigVariable = (
mapping: EnsembleVariableMappings,
ensembleVariableName: string,
modelConfigId: string
) => getModelConfigMappings(mapping, ensembleVariableName)[modelConfigId] ?? '';
) => mapping.find((d) => d.newName === ensembleVariableName)?.modelConfigurationMappings[modelConfigId] ?? '';

const getModelConfigIdPrefix = (modelId: string) => (modelId ? `${modelId}/` : '');
const getModelConfigIdPrefix = (configId: string) => (configId ? `${configId}/` : '');

/**
* Converts a model variable name to a dataset variable name based on the provided mapping.
Expand Down Expand Up @@ -94,6 +88,9 @@ const addModelConfigNameToTranslationMap = (
return newMap;
};

// Consider provided reference object is ready if it is set to null explicitly or if it's value is available
const isRefReady = (ref: Ref | null) => ref === null || Boolean(ref.value);

/**
* Composable to manage the creation and configuration of various types of charts used in operator nodes and drilldown.
*
Expand All @@ -115,6 +112,10 @@ export function useCharts(
interventions: Ref<Intervention[]> | null,
mapping: Ref<VariableMappings> | null
) {
// Check if references of the core dependencies are ready to build the chart to prevent multiple re-renders especially
// on initial page load where data are fetched asynchronously and assigned to the references in different times.
const isChartReadyToBuild = computed(() => [model, modelConfig, chartData].every(isRefReady));

// Setup annotations
const { getChartAnnotationsByChartId, generateAndSaveForecastChartAnnotation } = useChartAnnotations(nodeId);

Expand Down Expand Up @@ -160,7 +161,7 @@ export function useCharts(
variables.push(
modelConfigId
? // model variable
getModelConfigIdPrefix(modelConfigId ?? '') +
getModelConfigIdPrefix(modelConfigId) +
getModelConfigVariable(<EnsembleVariableMappings>mapping?.value ?? [], ensembleVarName, modelConfigId)
: // ensemble variable
(modelVarToDatasetVar(mapping?.value ?? [], ensembleVarName) as string)
Expand Down Expand Up @@ -239,8 +240,8 @@ export function useCharts(
const useInterventionCharts = (chartSettings: ComputedRef<ChartSetting[]>, showSamples = false) => {
const interventionCharts = computed(() => {
const charts: Record<string, VisualizationSpec> = {};
if (!chartData.value) return charts;
const { resultSummary, result } = chartData.value;
if (!isChartReadyToBuild.value) return charts;
const { resultSummary, result } = chartData.value as ChartData;
// intervention chart spec
chartSettings.value.forEach((setting) => {
const variable = setting.selectedVariables[0];
Expand Down Expand Up @@ -279,8 +280,8 @@ export function useCharts(
) => {
const variableCharts = computed(() => {
const charts: Record<string, VisualizationSpec> = {};
if (!chartData.value) return charts;
const { result, resultSummary } = chartData.value;
if (!isChartReadyToBuild.value || !isRefReady(groundTruthData)) return charts;
const { result, resultSummary } = chartData.value as ChartData;

chartSettings.value.forEach((settings) => {
const variable = settings.selectedVariables[0];
Expand Down Expand Up @@ -325,8 +326,8 @@ export function useCharts(
const useComparisonCharts = (chartSettings: ComputedRef<ChartSetting[]>) => {
const comparisonCharts = computed(() => {
const charts: Record<string, VisualizationSpec> = {};
if (!chartData.value) return charts;
const { result, resultSummary } = chartData.value;
if (!isChartReadyToBuild.value) return charts;
const { result, resultSummary } = chartData.value as ChartData;
chartSettings.value.forEach((setting) => {
const selectedVars = setting.selectedVariables;
const { statLayerVariables, sampleLayerVariables, options } = createForecastChartOptions(setting);
Expand Down Expand Up @@ -370,16 +371,13 @@ export function useCharts(
) => {
const ensembleVariableCharts = computed(() => {
const charts: Record<string, VisualizationSpec[]> = {};
if (!chartData.value) return charts;
const { result, resultSummary } = chartData.value;
if (!isChartReadyToBuild.value || !isRefReady(groundTruthData)) return chartData;
const { result, resultSummary } = chartData.value as ChartData;
const modelConfigIds = extractModelConfigIdsInOrder(chartData.value?.pyciemssMap ?? {});
chartSettings.value.forEach((setting) => {
const annotations = getChartAnnotationsByChartId(setting.id);
const datasetVar = modelVarToDatasetVar(mapping?.value || [], setting.selectedVariables[0]);
if (setting.showIndividualModels) {
// Build small multiples charts for each model configuration variable
const modelConfigIds = Object.keys(
getModelConfigMappings(<EnsembleVariableMappings>mapping?.value || [], setting.selectedVariables[0])
);
const smallMultiplesCharts = ['', ...modelConfigIds].map((modelConfigId, index) => {
const { sampleLayerVariables, statLayerVariables, options } = createEnsembleVariableChartOptions(
setting,
Expand Down Expand Up @@ -503,13 +501,13 @@ export function useCharts(
// Create parameter distribution charts based on chart settings
const useParameterDistributionCharts = (chartSettings: ComputedRef<ChartSetting[]>) => {
const parameterDistributionCharts = computed(() => {
if (!chartData.value) return {};
const { result, pyciemssMap } = chartData.value;
const charts = {};
if (!isChartReadyToBuild.value) return charts;
const { result, pyciemssMap } = chartData.value as ChartData;
// Note that we want to show the parameter distribution at the first timepoint only
const data = result.filter((d) => d.timepoint_id === 0);
const labelBefore = 'Before calibration';
const labelAfter = 'After calibration';
const charts = {};
chartSettings.value.forEach((setting) => {
const param = setting.selectedVariables[0];
const fieldName = pyciemssMap[param];
Expand Down Expand Up @@ -541,6 +539,52 @@ export function useCharts(
return parameterDistributionCharts;
};

const useWeightsDistributionCharts = () => {
const WEIGHT_PARAM_NAME = 'weight_param';
const weightsCharts = computed(() => {
const charts: VisualizationSpec[] = [];
if (!isChartReadyToBuild.value) return charts;

// Model configs are used to get the model config metadata. This order of model configs in arrays are not guaranteed to be the same as the order of model configs in the pyciemss results
const modelConfigs = <ModelConfiguration[]>modelConfig?.value ?? [];
// extractModelConfigIdsInOrder ensures that the order of model config IDs are matched with the order of corresponding model index in the pyciemss results
const modelConfigIds = extractModelConfigIdsInOrder(chartData.value?.pyciemssMap ?? {});

const data = chartData.value?.result.filter((d) => d.timepoint_id === 0) ?? [];
const labelBefore = 'Before calibration';
const labelAfter = 'After calibration';

const colors = CATEGORICAL_SCHEME.slice(1); // exclude the first color which is for ensemble variable

modelConfigIds.forEach((configId, index) => {
const modelConfigName = getModelConfigName(modelConfigs, configId);
const chartWidth = chartSize.value.width / modelConfigs.length;

const fieldName = chartData.value?.pyciemssMap[`${getModelConfigIdPrefix(configId)}${WEIGHT_PARAM_NAME}`] ?? '';
const beforeFieldName = `${fieldName}:pre`;

const maxBins = 10;
const barWidth = Math.min((chartWidth - 40) / maxBins, 54);
const spec = createHistogramChart(data, {
title: modelConfigName,
width: chartWidth,
height: chartSize.value.height,
xAxisTitle: `Weights`,
yAxisTitle: 'Count',
maxBins,
variables: [
{ field: beforeFieldName, label: labelBefore, width: barWidth, color: BASE_GREY },
{ field: fieldName, label: labelAfter, width: barWidth / 2, color: colors[index % colors.length] }
],
legendProperties: { direction: 'vertical', columns: 1, labelLimit: chartWidth }
});
charts.push(spec);
});
return charts;
});
return weightsCharts;
};

return {
generateAnnotation,
getChartAnnotationsByChartId,
Expand All @@ -549,6 +593,7 @@ export function useCharts(
useComparisonCharts,
useEnsembleVariableCharts,
useErrorChart,
useParameterDistributionCharts
useParameterDistributionCharts,
useWeightsDistributionCharts
};
}
4 changes: 3 additions & 1 deletion packages/client/hmi-client/src/services/charts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export interface ForecastChartLayer {
export interface HistogramChartOptions extends BaseChartOptions {
maxBins?: number;
variables: { field: string; label?: string; width: number; color: string }[];
legendProperties?: Record<string, any>;
}

export interface ErrorChartOptions extends Omit<BaseChartOptions, 'height' | 'yAxisTitle' | 'legend'> {
Expand Down Expand Up @@ -291,7 +292,8 @@ export function createHistogramChart(dataset: Record<string, any>[], options: Hi
symbolStrokeWidth: 4,
symbolSize: 200,
labelFontSize: 12,
labelOffset: 4
labelOffset: 4,
...options.legendProperties
};

const spec: VisualizationSpec = {
Expand Down
Loading

0 comments on commit d7a61f5

Please sign in to comment.