
Make optimization configurable #318

Merged · 11 commits · Aug 23, 2024
33 changes: 32 additions & 1 deletion alphadia/constants/default.yaml
@@ -224,8 +224,39 @@ search_output:
  # can be either "parquet" or "tsv"
  file_format: "tsv"

optimization:
  # The order in which to perform optimization. Should be a list of lists of parameter names.
  # Example: [["ms1_error", "ms2_error", "rt_error", "mobility_error"]] means that all parameters are optimized simultaneously.
  # Example: [["ms2_error"], ["rt_error"], ["ms1_error"], ["mobility_error"]] means that the parameters are optimized sequentially in the order given.
  # Example: [["rt_error"], ["ms1_error", "ms2_error"]] means that rt_error is optimized first, then ms1_error and ms2_error are optimized simultaneously, and mobility_error is not optimized at all.
  order_of_optimization: null

  # Update-rule parameters for each search parameter, with a targeted_ and an automatic_ variant for the two optimizer types:
  # - update_interval: the percentile interval to use (as a decimal)
  # - update_factor: the factor by which to multiply the result from the percentile interval to get the new parameter value for the next round of search
  ms2_error:
    targeted_update_interval: 0.95
    targeted_update_factor: 1.0
    automatic_update_interval: 0.99
    automatic_update_factor: 1.1
  ms1_error:
    targeted_update_interval: 0.95
    targeted_update_factor: 1.0
    automatic_update_interval: 0.99
    automatic_update_factor: 1.1
  mobility_error:
    targeted_update_interval: 0.95
    targeted_update_factor: 1.0
    automatic_update_interval: 0.99
    automatic_update_factor: 1.1
  rt_error:
    targeted_update_interval: 0.95
    targeted_update_factor: 1.0
    automatic_update_interval: 0.99
    automatic_update_factor: 1.1
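These defaults can be overridden per run. A minimal sketch in Python, mirroring how the unit test at the end of this PR mutates workflow.config; the bare `config` dict here is hypothetical, and only the key names come from default.yaml:

```python
# Hypothetical runtime override of the defaults above; key names match default.yaml.
config = {"optimization": {"order_of_optimization": None}}

# Optimize rt_error first, then ms1_error and ms2_error together;
# mobility_error is left out and therefore not optimized at all.
config["optimization"]["order_of_optimization"] = [
    ["rt_error"],
    ["ms1_error", "ms2_error"],
]

# Loosen the automatic update rule for ms2_error only.
config["optimization"]["ms2_error"] = {
    "automatic_update_interval": 0.95,
    "automatic_update_factor": 1.05,
}
```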

# configuration for the optimization manager
# initial parameters, will be optimized
optimization_manager:
  fwhm_rt: 5
  fwhm_mobility: 0.01
53 changes: 34 additions & 19 deletions alphadia/workflow/optimization.py
@@ -88,6 +88,12 @@ def __init__(
        self.workflow.optimization_manager.fit({self.parameter_name: initial_parameter})
        self.has_converged = False
        self.num_prev_optimizations = 0
        self.update_factor = workflow.config["optimization"][self.parameter_name][
            "automatic_update_factor"
        ]
        self.update_interval = workflow.config["optimization"][self.parameter_name][
            "automatic_update_interval"
        ]

    def step(
        self,
@@ -294,6 +300,12 @@ def __init__(
        super().__init__(workflow, reporter)
        self.workflow.optimization_manager.fit({self.parameter_name: initial_parameter})
        self.target_parameter = target_parameter
        self.update_factor = workflow.config["optimization"][self.parameter_name][
            "targeted_update_factor"
        ]
        self.update_interval = workflow.config["optimization"][self.parameter_name][
            "targeted_update_interval"
        ]
        self.has_converged = False

    def _check_convergence(self, proposed_parameter: float):
@@ -324,10 +336,10 @@ def _propose_new_parameter(self, df: pd.DataFrame):
        3) take the maximum of this value and the target parameter.
        This is implemented by the ci method for the estimator.
        """
        return self.update_factor * max(
            self.workflow.calibration_manager.get_estimator(
                self.estimator_group_name, self.estimator_name
            ).ci(df, self.update_interval),
            self.target_parameter,
        )
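A rough numeric illustration of this targeted rule; the function and values are hypothetical, with `ci_value` standing in for the estimator's confidence-interval result:

```python
def targeted_update(ci_value: float, target: float, update_factor: float = 1.0) -> float:
    # Widen to the observed CI if needed, but never shrink below the target tolerance.
    return update_factor * max(ci_value, target)

assert targeted_update(ci_value=12.0, target=15.0) == 15.0  # clamped up to the target
assert targeted_update(ci_value=20.0, target=15.0) == 20.0  # CI is wider, so keep it
```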

@@ -388,8 +400,9 @@ def __init__(
    def _propose_new_parameter(self, df: pd.DataFrame):
        """See base class. The update rule is
        1) calculate the deviation of the predicted mz values from the observed mz values,
        2) take the mean of the endpoints of the central interval of these deviations
        (whose width is set by the self.update_interval attribute, expressed as a decimal percentile), and
        3) multiply this value by self.update_factor.
        This is implemented by the ci method for the estimator.

        Returns
@@ -398,9 +411,9 @@ def _propose_new_parameter(self, df: pd.DataFrame):
            The proposed new value for the search parameter.

        """
        return self.update_factor * self.workflow.calibration_manager.get_estimator(
            self.estimator_group_name, self.estimator_name
        ).ci(df, self.update_interval)
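A minimal sketch of what this automatic update computes, assuming `ci` returns the mean of the absolute endpoints of the central `update_interval` fraction of the deviations; the function name and numbers are illustrative, not from the PR:

```python
import numpy as np

def automatic_update(deviations: np.ndarray, update_interval: float = 0.99,
                     update_factor: float = 1.1) -> float:
    # Endpoints of the central `update_interval` fraction of the deviations.
    lower = np.percentile(deviations, 50 * (1 - update_interval))
    upper = np.percentile(deviations, 50 * (1 + update_interval))
    # Mean of the endpoint magnitudes, inflated by the update factor.
    return update_factor * (abs(lower) + abs(upper)) / 2

rng = np.random.default_rng(0)
print(automatic_update(rng.normal(0.0, 3.0, size=10_000)))  # roughly 8.5 for sigma = 3
```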

    def _get_feature_value(
        self, precursors_df: pd.DataFrame, fragments_df: pd.DataFrame
@@ -425,8 +438,9 @@ def __init__(
    def _propose_new_parameter(self, df: pd.DataFrame):
        """See base class. The update rule is
        1) calculate the deviation of the predicted mz values from the observed mz values,
        2) take the mean of the endpoints of the central interval of these deviations
        (whose width is set by the self.update_interval attribute, expressed as a decimal percentile), and
        3) multiply this value by self.update_factor.
        This is implemented by the ci method for the estimator.

        Returns
@@ -435,9 +449,9 @@ def _propose_new_parameter(self, df: pd.DataFrame):
            The proposed new value for the search parameter.

        """
        return self.update_factor * self.workflow.calibration_manager.get_estimator(
            self.estimator_group_name, self.estimator_name
        ).ci(df, self.update_interval)

    def _get_feature_value(
        self, precursors_df: pd.DataFrame, fragments_df: pd.DataFrame
@@ -462,20 +476,20 @@ def __init__(
    def _propose_new_parameter(self, df: pd.DataFrame):
        """See base class. The update rule is
        1) calculate the deviation of the predicted mz values from the observed mz values,
        2) take the mean of the endpoints of the central interval of these deviations
        (whose width is set by the self.update_interval attribute, expressed as a decimal percentile), and
        3) multiply this value by self.update_factor.
        This is implemented by the ci method for the estimator.

        Returns
        -------
        float
            The proposed new value for the search parameter.

        """
        return self.update_factor * self.workflow.calibration_manager.get_estimator(
            self.estimator_group_name, self.estimator_name
        ).ci(df, self.update_interval)

    def _get_feature_value(
        self, precursors_df: pd.DataFrame, fragments_df: pd.DataFrame
@@ -500,8 +514,9 @@ def __init__(
    def _propose_new_parameter(self, df: pd.DataFrame):
        """See base class. The update rule is
        1) calculate the deviation of the predicted mz values from the observed mz values,
        2) take the mean of the endpoints of the central interval of these deviations
        (whose width is set by the self.update_interval attribute, expressed as a decimal percentile), and
        3) multiply this value by self.update_factor.
        This is implemented by the ci method for the estimator.

        Returns
@@ -511,9 +526,9 @@ def _propose_new_parameter(self, df: pd.DataFrame):

"""

return 1.1 * self.workflow.calibration_manager.get_estimator(
return self.update_factor * self.workflow.calibration_manager.get_estimator(
odespard marked this conversation as resolved.
Show resolved Hide resolved
self.estimator_group_name, self.estimator_name
).ci(df, 0.99)
).ci(df, self.update_interval)

    def _get_feature_value(
        self, precursors_df: pd.DataFrame, fragments_df: pd.DataFrame
65 changes: 42 additions & 23 deletions alphadia/workflow/peptidecentric.py
@@ -323,32 +323,51 @@ def get_ordered_optimizers(self):
        else:
            mobility_optimizer = None

        if self.config["optimization"]["order_of_optimization"] is None:
            optimizers = [
                ms2_optimizer,
                rt_optimizer,
                ms1_optimizer,
                mobility_optimizer,
            ]
            targeted_optimizers = [
                [
                    optimizer
                    for optimizer in optimizers
                    if isinstance(optimizer, optimization.TargetedOptimizer)
                ]
            ]
            automatic_optimizers = [
                [optimizer]
                for optimizer in optimizers
                if isinstance(optimizer, optimization.AutomaticOptimizer)
            ]

            ordered_optimizers = (
                targeted_optimizers + automatic_optimizers
                if any(
                    targeted_optimizers
                )  # This check is required so no empty list is added to the ordered_optimizers list
                else automatic_optimizers
            )
        else:
            opt_mapping = {
                "ms2_error": ms2_optimizer,
                "rt_error": rt_optimizer,
                "ms1_error": ms1_optimizer,
                "mobility_error": mobility_optimizer,
            }
            ordered_optimizers = []
            for optimizers_in_ordering in self.config["optimization"][
                "order_of_optimization"
            ]:
                ordered_optimizers += [
                    [
                        opt_mapping[opt]
                        for opt in optimizers_in_ordering
                        if opt_mapping[opt] is not None
                    ]
                ]

        return ordered_optimizers
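The ordering logic can be exercised standalone. A self-contained sketch, with plain strings standing in for optimizer objects; the function name and stand-in values are hypothetical:

```python
def order_optimizers(order_of_optimization, opt_mapping):
    # One inner list per round; optimizers in the same round run simultaneously.
    return [
        [opt_mapping[name] for name in round_ if opt_mapping[name] is not None]
        for round_ in order_of_optimization
    ]

opt_mapping = {"ms2_error": "ms2", "rt_error": "rt", "ms1_error": None, "mobility_error": "mob"}
print(order_optimizers([["rt_error"], ["ms1_error", "ms2_error"]], opt_mapping))
# -> [['rt'], ['ms2']]  (ms1_error is skipped because its optimizer is None)
```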

53 changes: 53 additions & 0 deletions tests/unit_tests/test_workflow.py
@@ -453,6 +453,12 @@ class MockOptlock:

    workflow.optlock = MockOptlock()

    class MockDIAData:
        has_mobility = True
        has_ms1 = True

    workflow._dia_data = MockDIAData()

    return workflow


@@ -915,6 +921,53 @@ def test_optlock_batch_idx():
    assert optlock.stop_idx == 2000


def test_configurability():
    workflow = create_workflow_instance()
    workflow.config["optimization"].update(
        {
            "order_of_optimization": [
                ["rt_error"],
                ["ms1_error", "ms2_error"],
                ["mobility_error"],
            ],
            "rt_error": {
                "automatic_update_interval": 0.99,
                "automatic_update_factor": 1.3,
            },
            "ms2_error": {
                "automatic_update_interval": 0.80,
                "targeted_update_interval": 0.995,
                "targeted_update_factor": 1.2,
            },
        }
    )
    workflow.config["search"].update(
        {
            "target_rt_tolerance": -1,
        }
    )

    ordered_optimizers = workflow.get_ordered_optimizers()

    assert len(ordered_optimizers) == 3

    assert ordered_optimizers[0][0].parameter_name == "rt_error"
    assert isinstance(ordered_optimizers[0][0], optimization.AutomaticRTOptimizer)
    assert ordered_optimizers[0][0].update_interval == 0.99
    assert ordered_optimizers[0][0].update_factor == 1.3

    assert ordered_optimizers[1][0].parameter_name == "ms1_error"
    assert ordered_optimizers[1][0].update_interval == 0.95
    assert isinstance(ordered_optimizers[1][0], optimization.TargetedMS1Optimizer)

    assert ordered_optimizers[1][1].parameter_name == "ms2_error"
    assert isinstance(ordered_optimizers[1][1], optimization.TargetedMS2Optimizer)
    assert ordered_optimizers[1][1].update_interval == 0.995
    assert ordered_optimizers[1][1].update_factor == 1.2

    assert ordered_optimizers[2][0].parameter_name == "mobility_error"

def test_optlock_reindex():
    library = create_test_library_for_indexing()
    optlock = optimization.OptimizationLock(library, TEST_OPTLOCK_CONFIG)