diff --git a/pypesto/select/method.py b/pypesto/select/method.py
index 36a726be6..d1ff12443 100644
--- a/pypesto/select/method.py
+++ b/pypesto/select/method.py
@@ -8,6 +8,10 @@
 import numpy as np
 import petab_select
 from petab_select import (
+    CANDIDATE_SPACE,
+    MODELS,
+    PREDECESSOR_MODEL,
+    UNCALIBRATED_MODELS,
     VIRTUAL_INITIAL_MODEL,
     CandidateSpace,
     Criterion,
@@ -213,6 +217,11 @@ class MethodCaller:
         Specify the predecessor (initial) model for the model selection
         algorithm. If ``None``, then the algorithm will generate an initial
         predecessor model if required.
+    user_calibrated_models:
+        Supply calibration results for models yourself, as a list of models.
+        If a model with the same hash is encountered in the current model
+        selection run, and the user-supplied calibrated model has the
+        `criterion` value set, the model will not be calibrated again.
     select_first_improvement:
         If ``True``, model selection will terminate as soon as a better
         model is found. If `False`, all candidate models will be tested.
@@ -245,6 +254,7 @@ def __init__(
         # TODO deprecated
         model_to_pypesto_problem_method: Callable[[Any], Problem] = None,
         model_problem_options: dict = None,
+        user_calibrated_models: list[Model] = None,
     ):
         """Arguments are used in every `__call__`, unless overridden."""
         self.petab_select_problem = petab_select_problem
@@ -256,6 +266,12 @@ def __init__(
         self.select_first_improvement = select_first_improvement
         self.startpoint_latest_mle = startpoint_latest_mle
 
+        self.user_calibrated_models = {}
+        if user_calibrated_models is not None:
+            self.user_calibrated_models = {
+                model.get_hash(): model for model in user_calibrated_models
+            }
+
         self.logger = MethodLogger()
 
         # TODO deprecated
@@ -335,10 +351,7 @@ def __init__(
         # May have changed from `None` to `petab_select.VIRTUAL_INITIAL_MODEL`
         self.predecessor_model = self.candidate_space.get_predecessor_model()
 
-    def __call__(
-        self,
-        newly_calibrated_models: Optional[dict[str, Model]] = None,
-    ) -> tuple[list[Model], dict[str, Model]]:
+    def __call__(self) -> tuple[list[Model], dict[str, Model]]:
         """Run a single iteration of the model selection method.
 
         A single iteration here refers to calibration of all candidate models.
@@ -347,14 +360,6 @@ def __call__(
         of all models that have both: the same 3 estimated parameters; and
         1 additional estimated parameter.
 
-        The input `newly_calibrated_models` is from the previous iteration. The
-        output `newly_calibrated_models` is from the current iteration.
-
-        Parameters
-        ----------
-        newly_calibrated_models:
-            The newly calibrated models from the previous iteration.
-
         Returns
         -------
         A 2-tuple, with the following values:
@@ -366,39 +371,56 @@ def __call__(
         # All calibrated models in this iteration (see second return value).
         self.logger.new_selection()
 
-        candidate_space = petab_select.ui.candidates(
+        iteration = petab_select.ui.start_iteration(
             problem=self.petab_select_problem,
             candidate_space=self.candidate_space,
             limit=self.limit,
-            calibrated_models=self.calibrated_models,
-            newly_calibrated_models=newly_calibrated_models,
-            excluded_model_hashes=self.calibrated_models.keys(),
             criterion=self.criterion,
+            user_calibrated_models=self.user_calibrated_models,
         )
-        predecessor_model = self.candidate_space.predecessor_model
 
-        if not candidate_space.models:
+        if not iteration[UNCALIBRATED_MODELS]:
             raise StopIteration("No valid models found.")
 
         # TODO parallelize calibration (maybe not sensible if
         # `self.select_first_improvement`)
-        newly_calibrated_models = {}
-        for candidate_model in candidate_space.models:
-            # autoruns calibration
-            self.new_model_problem(model=candidate_model)
-            newly_calibrated_models[
-                candidate_model.get_hash()
-            ] = candidate_model
+        calibrated_models = {}
+        for model in iteration[UNCALIBRATED_MODELS]:
+            if (
+                model.get_criterion(
+                    criterion=self.criterion,
+                    compute=True,
+                    raise_on_failure=False,
+                )
+                is not None
+            ):
+                self.logger.log(
+                    message=(
+                        "Unexpected calibration result already available for "
+                        f"model: `{model.get_hash()}`. Skipping "
+                        "calibration."
+                    ),
+                    level="warning",
+                )
+            else:
+                self.new_model_problem(model=model)
+
+            calibrated_models[model.get_hash()] = model
 
             method_signal = self.handle_calibrated_model(
-                model=candidate_model,
-                predecessor_model=predecessor_model,
+                model=model,
+                predecessor_model=iteration[PREDECESSOR_MODEL],
             )
             if method_signal.proceed == MethodSignalProceed.STOP:
                 break
 
-        self.calibrated_models.update(newly_calibrated_models)
+        iteration_results = petab_select.ui.end_iteration(
+            candidate_space=iteration[CANDIDATE_SPACE],
+            calibrated_models=calibrated_models,
+        )
+
+        self.calibrated_models.update(iteration_results[MODELS])
 
-        return predecessor_model, newly_calibrated_models
+        return iteration[PREDECESSOR_MODEL], iteration_results[MODELS]
 
     def handle_calibrated_model(
         self,
diff --git a/pypesto/select/problem.py b/pypesto/select/problem.py
index 9692890ef..e0bd08cf3 100644
--- a/pypesto/select/problem.py
+++ b/pypesto/select/problem.py
@@ -156,16 +156,13 @@ def select(
         self.handle_select_kwargs(kwargs)
         # TODO handle bidirectional
         method_caller = self.create_method_caller(**kwargs)
-        previous_best_model, newly_calibrated_models = method_caller(
-            # TODO add predecessor model to state
-            newly_calibrated_models=self.newly_calibrated_models,
-        )
+        previous_best_model, newly_calibrated_models = method_caller()
         self.update_with_newly_calibrated_models(
             newly_calibrated_models=newly_calibrated_models,
         )
 
-        best_model = petab_select.ui.best(
+        best_model = petab_select.ui.get_best(
             problem=self.petab_select_problem,
             models=self.newly_calibrated_models.values(),
             criterion=method_caller.criterion,
@@ -198,9 +195,7 @@ def select_to_completion(
 
         while True:
             try:
-                previous_best_model, newly_calibrated_models = method_caller(
-                    newly_calibrated_models=self.newly_calibrated_models,
-                )
+                previous_best_model, newly_calibrated_models = method_caller()
                 self.update_with_newly_calibrated_models(
                     newly_calibrated_models=newly_calibrated_models,
                 )
@@ -247,33 +242,18 @@ def multistart_select(
         """
         self.handle_select_kwargs(kwargs)
         model_lists = []
-        newly_calibrated_models_list = [
-            self.newly_calibrated_models for _ in predecessor_models
-        ]
-        method_caller = self.create_method_caller(**kwargs)
-        for start_index, predecessor_model in enumerate(predecessor_models):
-            method_caller.candidate_space.previous_predecessor_model = (
-                predecessor_model
-            )
-            (
-                best_model,
-                newly_calibrated_models_list[start_index],
-            ) = method_caller(
-                newly_calibrated_models=newly_calibrated_models_list[
-                    start_index
-                ],
-            )
-            self.calibrated_models.update(
-                newly_calibrated_models_list[start_index]
+        for predecessor_model in predecessor_models:
+            method_caller = self.create_method_caller(
+                **(kwargs | {"predecessor_model": predecessor_model})
             )
+            (best_model, models) = method_caller()
+            self.calibrated_models |= models
 
-            model_lists.append(
-                newly_calibrated_models_list[start_index].values()
-            )
+            model_lists.append(list(models.values()))
             method_caller.candidate_space.reset()
 
-        best_model = petab_select.ui.best(
+        best_model = petab_select.ui.get_best(
             problem=method_caller.petab_select_problem,
             models=[model for models in model_lists for model in models],
             criterion=method_caller.criterion,
diff --git a/setup.cfg b/setup.cfg
index 4fc1a4e1b..bc28ae610 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -163,10 +163,12 @@ example =
     ipywidgets >= 8.1.5
     benchmark_models_petab @ git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python
 select =
-    # Remove when vis is moved to PEtab Select version
+    # Remove when vis is moved to PEtab Select
    networkx >= 2.5.1
     # End remove
-    petab-select >= 0.1.12
+    #petab-select >= 0.1.12
+    # FIXME before merge
+    petab-select @ git+https://github.com/PEtab-dev/petab_select.git@develop
 test =
     pytest >= 5.4.3
     pytest-cov >= 2.10.0
diff --git a/test/select/test_select.py b/test/select/test_select.py
index b471365d6..5ba2d5a3e 100644
--- a/test/select/test_select.py
+++ b/test/select/test_select.py
@@ -279,6 +279,8 @@ def test_problem_multistart_select(pypesto_select_problem, initial_models):
         "M1_3": -4.705,
         # 'M1_7': -4.056,  # skipped -- reproducibility requires many starts
     }
+    # As M1_7 criterion comparison is skipped, at least ensure it is present
+    assert {m.model_subspace_id for m in best_models} == {"M1_3", "M1_7"}
     test_best_models_criterion_values = {
         model.model_subspace_id: model.get_criterion(Criterion.AIC)
         for model in best_models
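# ---------------------------------------------------------------------------
# Usage sketch (not part of the diff above): how the new
# `user_calibrated_models` argument introduced in `MethodCaller` could be
# used from the existing `pypesto.select.Problem` entry point, assuming that
# `select()` forwards keyword arguments to `MethodCaller` as in problem.py
# above. The YAML path and the criterion choice are placeholders.
# ---------------------------------------------------------------------------
import petab_select
from petab_select import Criterion

import pypesto.select

# Load the PEtab Select problem and wrap it for pypesto.
petab_select_problem = petab_select.Problem.from_yaml(
    "petab_select_problem.yaml"  # placeholder path
)
pypesto_select_problem = pypesto.select.Problem(
    petab_select_problem=petab_select_problem
)

# Models calibrated in an earlier run, e.g. reloaded from disk. If such a
# model is proposed again and already provides the criterion value,
# `MethodCaller.__call__` logs a message and skips its calibration.
previously_calibrated_models: list[petab_select.Model] = []

# Expected to return the best model of this iteration together with the
# newly calibrated models (see `Problem.select` in problem.py).
best_model, newly_calibrated_models = pypesto_select_problem.select(
    criterion=Criterion.AIC,
    user_calibrated_models=previously_calibrated_models,
)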