From db721398c2fd7d6392eb58550ec5d0d75c29ec4a Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 23 Mar 2020 11:18:45 -0700 Subject: [PATCH] add exception handling and retry on loading train batch (#194) * add exception handling and retry on loading train batch * adjust number of batches in config record if bad batch skipped --- environment.yml | 1 + fv3net/regression/dataset_handler.py | 67 ++++++++++++++++++++-------- fv3net/regression/sklearn/train.py | 43 ++++++++++++++---- 3 files changed, 83 insertions(+), 28 deletions(-) diff --git a/environment.yml b/environment.yml index 593d9ef017..ae13d0cef8 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ channels: - defaults - conda-forge dependencies: + - backoff - cartopy - click - conda-build diff --git a/fv3net/regression/dataset_handler.py b/fv3net/regression/dataset_handler.py index 385db997e3..33eaf0a00f 100644 --- a/fv3net/regression/dataset_handler.py +++ b/fv3net/regression/dataset_handler.py @@ -1,14 +1,13 @@ +import backoff import logging from dataclasses import dataclass from typing import List - -import gcsfs -from math import ceil import numpy as np import xarray as xr +import vcm +from vcm.cloud.fsspec import get_fs from vcm.cubedsphere.constants import COORD_Z_CENTER, INIT_TIME_DIM -from vcm.select import mask_to_surface_type SAMPLE_DIM = "sample" @@ -19,13 +18,20 @@ logger.addHandler(fh) +class RemoteDataError(Exception): + """ Raised for errors reading data from the cloud that + may be resolved upon retry. + """ + + pass + + @dataclass class BatchGenerator: data_vars: List[str] gcs_data_dir: str files_per_batch: int num_batches: int = None - gcs_project: str = "vcm-ml" random_seed: int = 1234 mask_to_surface_type: str = "none" @@ -36,26 +42,26 @@ def __post_init__(self): nested list of zarr paths, where inner lists are the sets of zarrs used to train each batch """ - self.fs = gcsfs.GCSFileSystem(project=self.gcs_project) - print(f"Reading data from {self.gcs_data_dir}.") + self.fs = get_fs(self.gcs_data_dir) + logger.info(f"Reading data from {self.gcs_data_dir}.") zarr_urls = [ zarr_file for zarr_file in self.fs.ls(self.gcs_data_dir) if "grid_spec" not in zarr_file ] total_num_input_files = len(zarr_urls) - print(f"Number of .zarrs read from GCS: {total_num_input_files}.") + logger.info(f"Number of .zarrs in GCS train data dir: {total_num_input_files}.") np.random.seed(self.random_seed) np.random.shuffle(zarr_urls) - num_batches = self._validated_num_batches(total_num_input_files) - print(f"{num_batches} data batches generated for model training.") + self.num_batches = self._validated_num_batches(total_num_input_files) + logger.info(f"{self.num_batches} data batches generated for model training.") self.train_file_batches = [ zarr_urls[ batch_num * self.files_per_batch : (batch_num + 1) * self.files_per_batch ] - for batch_num in range(num_batches) + for batch_num in range(self.num_batches) ] def generate_batches(self): @@ -69,12 +75,38 @@ def generate_batches(self): """ grouped_urls = self.train_file_batches for file_batch_urls in grouped_urls: - fs_paths = [self.fs.get_mapper(url) for url in file_batch_urls] - ds = xr.concat(map(xr.open_zarr, fs_paths), INIT_TIME_DIM) - ds = mask_to_surface_type(ds, self.mask_to_surface_type) + try: + ds_shuffled = self._create_training_batch_with_retries(file_batch_urls) + except ValueError: + logger.error( + f"Failed to generate batch from files {file_batch_urls}." + "Skipping to next batch." 
+ ) + continue + yield ds_shuffled + + @backoff.on_exception(backoff.expo, RemoteDataError, max_tries=3) + def _create_training_batch_with_retries(self, urls): + timestep_paths = [self.fs.get_mapper(url) for url in urls] + try: + ds = xr.concat(map(xr.open_zarr, timestep_paths), INIT_TIME_DIM) + ds = vcm.mask_to_surface_type(ds, self.mask_to_surface_type) ds_stacked = stack_and_drop_nan_samples(ds).unify_chunks() ds_shuffled = _shuffled(ds_stacked, SAMPLE_DIM, self.random_seed) - yield ds_shuffled + return ds_shuffled + except ValueError as e: + # error when attempting to read from GCS that sometimes resolves on retry + if "array not found at path" in str(e): + logger.error( + f"Error reading data from {timestep_paths}, will retry. {e}" + ) + raise RemoteDataError( + f"Failed to read data from remote location: {str(e)}" + ) + # other errors that will not recover on retry + else: + logger.error(f"Error reading data from {timestep_paths}. {e}") + raise e def _validated_num_batches(self, total_num_input_files): """ check that the number of batches (if provided) and the number of @@ -90,10 +122,7 @@ def _validated_num_batches(self, total_num_input_files): elif self.num_batches * self.files_per_batch > total_num_input_files: if self.num_batches > total_num_input_files: raise ValueError("Fewer input files than number of requested batches.") - num_train_batches = self.num_batches - ceil( - (self.num_batches * self.files_per_batch - total_num_input_files) - / self.num_batches - ) + num_train_batches = total_num_input_files // self.files_per_batch else: num_train_batches = self.num_batches return num_train_batches diff --git a/fv3net/regression/sklearn/train.py b/fv3net/regression/sklearn/train.py index 550a10b38f..cd989b4e30 100644 --- a/fv3net/regression/sklearn/train.py +++ b/fv3net/regression/sklearn/train.py @@ -1,5 +1,6 @@ import argparse import joblib +import logging import os import yaml @@ -15,6 +16,12 @@ MODEL_CONFIG_FILENAME = "training_config.yml" MODEL_FILENAME = "sklearn_model.pkl" +logger = logging.getLogger() +logger.setLevel(logging.INFO) +fh = logging.FileHandler("ml_training.log") +fh.setLevel(logging.INFO) +logger.addHandler(fh) + @dataclass class ModelTrainingConfig: @@ -33,6 +40,17 @@ class ModelTrainingConfig: random_seed: int = 1234 mask_to_surface_type: str = "none" + def validate_number_train_batches(self, batch_generator): + """ Since number of training files specified may be larger than + the actual number available, this adds an attribute num_batches_used + that keeps information about the actual number of training batches + used. 
+
+        Args:
+            batch_generator (BatchGenerator)
+        """
+        self.num_batches_used = batch_generator.num_batches
 
 
 def load_model_training_config(config_path, gcs_data_dir):
     """
@@ -115,18 +133,26 @@ def train_model(batched_data, train_config):
     target_transformer = StandardScaler()
     transform_regressor = TransformedTargetRegressor(base_regressor, target_transformer)
     batch_regressor = RegressorEnsemble(transform_regressor)
     model_wrapper = SklearnWrapper(batch_regressor)
+    train_config.validate_number_train_batches(batched_data)
+
     for i, batch in enumerate(batched_data.generate_batches()):
-        print(f"Fitting batch {i}/{batched_data.num_batches}")
-        model_wrapper.fit(
-            input_vars=train_config.input_variables,
-            output_vars=train_config.output_variables,
-            sample_dim="sample",
-            data=batch,
-        )
-        print(f"Batch {i} done fitting.")
+        logger.info(f"Fitting batch {i}/{batched_data.num_batches}")
+        try:
+            model_wrapper.fit(
+                input_vars=train_config.input_variables,
+                output_vars=train_config.output_variables,
+                sample_dim="sample",
+                data=batch,
+            )
+            logger.info(f"Batch {i} done fitting.")
+        except ValueError as e:
+            logger.error(f"Error training on batch {i}: {e}")
+            train_config.num_batches_used -= 1
+            continue
+
     return model_wrapper
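
Below is a minimal, self-contained sketch of the backoff retry pattern this patch relies on: backoff.on_exception re-invokes the decorated function with exponential backoff each time it raises the given exception type, and gives up after max_tries attempts. The names here (flaky_read, the module-level attempts counter) are illustrative only and are not part of fv3net.

import backoff


class RemoteDataError(Exception):
    """Raised for read errors that may resolve on retry."""


attempts = {"count": 0}


@backoff.on_exception(backoff.expo, RemoteDataError, max_tries=3)
def flaky_read():
    # Fail twice with the retryable error, then succeed on the third call.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RemoteDataError(f"transient failure on attempt {attempts['count']}")
    return "data"


print(flaky_read())  # retried twice with exponential backoff, then prints "data"

With max_tries=3, the decorated function is called at most three times in total; exceptions other than RemoteDataError propagate immediately, which is why _create_training_batch_with_retries only converts the known transient "array not found at path" failure into RemoteDataError and re-raises everything else.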