From 95baec7afb143e878e8c2c7709a876c273bbe957 Mon Sep 17 00:00:00 2001 From: Alec Wysoker Date: Thu, 24 Aug 2023 14:57:47 -0400 Subject: [PATCH] Move hash computation so that it is recomputed on retry, and the now-invalid checkpoint is not loaded. If the number of tries is exhausted and ELBO tests are still failing, allow the run to complete anyway (using the checkpoint) so that outputs are produced, but exit with status 1. --- cellbender/remove_background/cli.py | 11 ----------- cellbender/remove_background/run.py | 25 ++++++++++++++++++++++++- cellbender/remove_background/train.py | 2 +- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/cellbender/remove_background/cli.py b/cellbender/remove_background/cli.py index ad49c01..29f7928 100644 --- a/cellbender/remove_background/cli.py +++ b/cellbender/remove_background/cli.py @@ -207,17 +207,6 @@ def setup_and_logging(args): + ' '.join(['cellbender', 'remove-background'] + sys.argv[2:])) logger.info("CellBender " + get_version()) - # Set up checkpointing by creating a unique workflow hash. 
- hashcode = create_workflow_hashcode( - module_path=os.path.dirname(cellbender.__file__), - args_to_remove=(['output_file', 'fpr', 'input_checkpoint_tarball', 'debug', - 'posterior_batch_size', 'checkpoint_min', 'truth_file', - 'posterior_regularization', 'cdf_threshold_q', 'prq_alpha', - 'estimator', 'use_multiprocessing_estimation', 'cpu_threads'] - + (['epochs'] if args.constant_learning_rate else [])), - args=args)[:10] - args.checkpoint_filename = hashcode # store this in args - logger.info(f'(Workflow hash {hashcode})') return args, file_handler diff --git a/cellbender/remove_background/run.py b/cellbender/remove_background/run.py index 7b777e4..42cf5b4 100644 --- a/cellbender/remove_background/run.py +++ b/cellbender/remove_background/run.py @@ -1,5 +1,6 @@ """Single run of remove-background, given input arguments.""" +import cellbender from cellbender.remove_background.model import RemoveBackgroundPyroModel from cellbender.remove_background.data.dataset import get_dataset_obj, \ SingleCellRNACountsDataset @@ -21,6 +22,7 @@ from cellbender.remove_background.sparse_utils import csr_set_rows_to_zero from cellbender.remove_background.data.io import write_matrix_to_cellranger_h5 from cellbender.remove_background.report import run_notebook_make_html, plot_summary +from cellbender.remove_background.checkpoint import create_workflow_hashcode import pyro from pyro.infer import SVI, JitTraceEnum_ELBO, JitTrace_ELBO, \ @@ -59,6 +61,22 @@ def run_remove_background(args: argparse.Namespace) -> Posterior: """ + # Set up checkpointing by creating a unique workflow hash. 
+ hashcode = create_workflow_hashcode( + module_path=os.path.dirname(cellbender.__file__), + args_to_remove=(['output_file', 'fpr', 'input_checkpoint_tarball', 'debug', + 'posterior_batch_size', 'checkpoint_min', 'truth_file', + 'posterior_regularization', 'cdf_threshold_q', 'prq_alpha', + 'estimator', 'use_multiprocessing_estimation', 'cpu_threads', + # The following settings do not affect the results, and can change when retrying, + # so remove them. + 'epoch_elbo_fail_fraction', 'final_elbo_fail_fraction', + 'num_failed_attempts', 'checkpoint_filename'] + + (['epochs'] if args.constant_learning_rate else [])), + args=args)[:10] + args.checkpoint_filename = hashcode # store this in args + logger.info(f'(Workflow hash {hashcode})') + # Handle initial random state. pyro.util.set_rng_seed(consts.RANDOM_SEED) if torch.cuda.is_available(): @@ -771,7 +789,12 @@ def run_inference(dataset_obj: SingleCellRNACountsDataset, sys.exit(0) else: logger.info(f'No more attempts are specified by --num-training-tries. ' - f'Therefore the workflow will abort here.') + f'Therefore the workflow will run once more without ELBO restrictions.') + args.epoch_elbo_fail_fraction = None + args.final_elbo_fail_fraction = None + run_remove_background(args) # start from scratch + # non-zero exit status in order to draw user's attention to the fact that ELBO tests + # were never satisfied. 
sys.exit(1) logger.info("Inference procedure complete.") diff --git a/cellbender/remove_background/train.py b/cellbender/remove_background/train.py index 8d77808..c333ce4 100644 --- a/cellbender/remove_background/train.py +++ b/cellbender/remove_background/train.py @@ -224,7 +224,7 @@ def run_training(model: RemoveBackgroundPyroModel, overall_diff = np.abs(test_elbo[-2] - test_elbo[0]) fractional_spike = current_diff / overall_diff if fractional_spike > epoch_elbo_fail_fraction: - raise ElboException( + raise ElboException( f'Training failed because test loss moved {current_diff:.2f} ' f'in the wrong direction, and that is {fractional_spike:.2f} ' f'of the total test ELBO change, > '