Skip to content

Commit

Permalink
Add option to drop time steps.
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Pfreundschuh committed Aug 30, 2024
1 parent 502bb07 commit d50341f
Showing 1 changed file with 58 additions and 15 deletions.
73 changes: 58 additions & 15 deletions chimp/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,14 @@ def process_tile(
input_maps: Union[torch.Tensor, List[torch.Tensor]],
age_maps: Union[torch.Tensor, List[torch.Tensor]],
slcs: Tuple[slice],
drop_steps: Optional[int],
metrics: Dict[str, List[ScalarMetric]],
metrics_conditional: Dict[str, ScalarMetric],
metrics_step: Dict[str, ScalarMetric],
metrics_forecast: Dict[str, ScalarMetric],
metrics_persistence: Dict[str, ScalarMetric],
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16
dtype: torch.dtype = torch.bfloat16,
) -> None:
"""
Evaluate predictions for a single tile.
Expand All @@ -108,6 +109,8 @@ def process_tile(
for each input time step.
age_maps: List of tensors containing the age maps for all retrieval input steps.
slcs: A tuple of slices to extract the valid domain.
drop_steps: Optional integer specifying the number of time steps to
drop from both the beginning and the end of the sequence before the
retrieval is evaluated.
metrics: A dictionary mapping target names to corresponding metrics
to compute.
metrics_conditional: A nested dictionary mapping target names to
Expand Down Expand Up @@ -148,13 +151,13 @@ def process_tile(
n_fc = 0

if model.inference_config is None:
seq_len = None
seq_len = 8
else:
seq_len = model.inference_config.input_loader_args.get(
"sequence_length",
None
)
seq_len = 8

if seq_len is None:
inputs = invert_sequence(inputs)
y_pred = {}
Expand Down Expand Up @@ -186,11 +189,17 @@ def process_tile(
y_preds_k_r = y_preds_k
targets_k_r = targets_k

if drop_steps is not None:
y_preds_k_r = y_preds_k_r[drop_steps:-drop_steps]
targets_k_r = targets_k_r[drop_steps:-drop_steps]
input_maps_r = input_maps[drop_steps:-drop_steps]
age_maps_r = age_maps[drop_steps:-drop_steps]

# Evaluate retrieval.
for step, (y_pred_k, target_k, input_map) in enumerate(zip(
y_preds_k_r,
targets_k_r,
input_maps
input_maps_r
)):
if target_k.mask.all():
continue
Expand All @@ -214,8 +223,8 @@ def process_tile(

for metric in metrics_cond:
metric = metric.to(device=device)
if len(age_maps) > 1:
age_map = age_maps[step].__getitem__((..., ind) + slcs)
if len(age_maps_r) > 1:
age_map = age_maps_r[step].__getitem__((..., ind) + slcs)
target_k_c = target_k.detach().clone()
target_k_c.mask[torch.isnan(age_map)] = True
metric.update(y_pred_k_mean, target_k_c, conditional={"age": age_map})
Expand Down Expand Up @@ -259,7 +268,9 @@ def run_tests(
tile_size: Optional[int] = None,
batch_size: int = 32,
device: str = "cuda",
dtype: str = "float32"
dtype: str = "float32",
drop: Optional[List[str]] = None,
drop_steps: Optional[int] = None
) -> xr.Dataset:
"""
Evaluate retrieval module on test set.
Expand All @@ -272,10 +283,16 @@ def run_tests(
tile_size: A tile size to use for the evaluation.
device: The device on which to perform the evaluation.
dtype: The dtype to use.
drop: Optional list of names of inputs whose data will be replaced with
NaN, i.e., set to missing.
drop_steps: Optional number of retrieval steps that will be
ignored at both the start and the end of each sequence.
Return:
An xarray.Dataset containing the calculated error metrics.
"""
if drop is None:
drop = []

if conditional:
sequence_length = 0
forecast = 0
Expand Down Expand Up @@ -313,7 +330,9 @@ def run_tests(
metrics_step = {
target: [
mtrcs.Bias(conditional={"step": test_dataset.sequence_length}),
mtrcs.MAE(conditional={"step": test_dataset.sequence_length}),
mtrcs.MSE(conditional={"step": test_dataset.sequence_length}),
mtrcs.SMAPE(conditional={"step": test_dataset.sequence_length}),
mtrcs.CorrelationCoef(conditional={"step": test_dataset.sequence_length}),
] for target in metrics
}
Expand All @@ -323,15 +342,19 @@ def run_tests(
metrics_forecast = {
target: [
mtrcs.Bias(conditional={"step": test_dataset.forecast}),
mtrcs.MSE(conditional={"step": test_dataset.forecast}),
mtrcs.MAE(conditional={"step": test_dataset.sequence_length}),
mtrcs.MSE(conditional={"step": test_dataset.sequence_length}),
mtrcs.SMAPE(conditional={"step": test_dataset.sequence_length}),
mtrcs.CorrelationCoef(conditional={"step": test_dataset.forecast}),
] for target in metrics
}
if test_dataset.include_input_steps:
metrics_persistence = {
target: [
mtrcs.Bias(conditional={"step": test_dataset.forecast}),
mtrcs.MSE(conditional={"step": test_dataset.forecast}),
mtrcs.MAE(conditional={"step": test_dataset.sequence_length}),
mtrcs.MSE(conditional={"step": test_dataset.sequence_length}),
mtrcs.SMAPE(conditional={"step": test_dataset.sequence_length}),
mtrcs.CorrelationCoef(conditional={"step": test_dataset.forecast}),
] for target in metrics
}
Expand All @@ -352,8 +375,6 @@ def run_tests(
shuffle=False,
)

print("LEN :: ", len(data_loader))

with Progress() as progress:
task = progress.add_task("Evaluating retrieval model: ", total=len(data_loader))

Expand All @@ -375,9 +396,15 @@ def run_tests(
if lead_time is not None:
x["lead_time"] = lead_time

for name in drop:
if isinstance(x[name], list):
x[name] = [torch.nan * x_s for x_s in x[name]]
else:
x[name] = torch.nan * x[name]

slcs = tiler.get_slices(row_ind, col_ind)
process_tile(
model, x, y, input_map, age_map, slcs, metrics, metrics_conditional,
model, x, y, input_map, age_map, slcs, drop_steps, metrics, metrics_conditional,
metrics_step=metrics_step, metrics_forecast=metrics_forecast,
metrics_persistence=metrics_persistence, device=device, dtype=dtype
)
Expand Down Expand Up @@ -436,6 +463,8 @@ def run_tests(
@click.option("--sequence_length", type=int, default=None)
@click.option("--forecast", type=int, default=0)
@click.option("-v", "--verbose", count=True)
@click.option("--drop", type=str, default=None)
@click.option("--drop_steps", type=int, default=None)
def cli(
model: Path,
test_data_path: str,
Expand All @@ -448,7 +477,9 @@ def cli(
verbose: int = 0,
batch_size: int = 32,
sequence_length: Optional[int] = 1,
forecast: int = 0
forecast: int = 0,
drop: Optional[str] = None,
drop_steps: Optional[int] = None
) -> int:
"""
Evaluate model on test data located in TEST_DATA_PATH and write results to
Expand All @@ -470,6 +501,11 @@ def cli(
)
return 1

if drop_steps is None:
sample_rate = 1
else:
sample_rate = sequence_length // (sequence_length - 2 * drop_steps)

test_data = SequenceDataset(
test_data_path,
input_datasets=input_datasets,
Expand All @@ -480,19 +516,24 @@ def cli(
sequence_length=sequence_length,
forecast=forecast,
include_input_steps=True,
sample_rate=1
sample_rate=sample_rate
)

metrics = {
name: [
mtrcs.Bias(),
mtrcs.MSE(),
mtrcs.MAE(),
mtrcs.SMAPE(),
mtrcs.CorrelationCoef()
] for name in model.to_config_dict()["output"].keys()
}

dtype = getattr(torch, dtype)

if drop is not None:
drop = drop.split(",")

retrieval_results, forecast_results = run_tests(
model,
test_data,
Expand All @@ -501,7 +542,9 @@ def cli(
tile_size=tile_size,
device=device,
dtype=dtype,
batch_size=batch_size
batch_size=batch_size,
drop=drop,
drop_steps=drop_steps
)

if retrieval_results is not None:
Expand Down

0 comments on commit d50341f

Please sign in to comment.