Fix numerous issues with tool input format "21.01" #19030

Merged · 2 commits · Nov 1, 2024
4 changes: 3 additions & 1 deletion lib/galaxy/tools/__init__.py
@@ -1836,7 +1836,9 @@ def expand_incoming(
# Expand these out to individual parameters for given jobs (tool executions).
expanded_incomings: List[ToolStateJobInstanceT]
collection_info: Optional[MatchingCollections]
expanded_incomings, collection_info = expand_meta_parameters(request_context, self, incoming)
expanded_incomings, collection_info = expand_meta_parameters(
request_context, self, incoming, input_format=input_format
)

self._ensure_expansion_is_valid(expanded_incomings, rerun_remap_job_id)

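The change above threads the request's `input_format` through to `expand_meta_parameters`, because tool state now arrives in two shapes: the legacy flat dictionary with pipe-delimited keys, and the "21.01" format in which the structure mirrors the tool's input tree. A minimal sketch of the difference, using hypothetical parameter names and dataset ids:

    # Hypothetical tool with a conditional "adv", a repeat "queries", and a data input;
    # names and ids are made up for illustration only.

    # legacy (flat) format: one key per fully-qualified parameter path
    legacy_incoming = {
        "adv|mode": "verbose",
        "queries_0|input": {"src": "hda", "id": "abc123"},
    }

    # "21.01" (nested) format: dictionaries and lists mirror the tool's input tree
    nested_incoming = {
        "adv": {"mode": "verbose"},
        "queries": [{"input": {"src": "hda", "id": "abc123"}}],
    }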
39 changes: 27 additions & 12 deletions lib/galaxy/tools/parameters/__init__.py
@@ -62,6 +62,8 @@ def visit_input_values(
context=None,
no_replacement_value=REPLACE_ON_TRUTHY,
replace_optional_connections=False,
allow_case_inference=False,
unset_value=None,
):
"""
Given a tools parameter definition (`inputs`) and a specific set of
@@ -158,7 +160,7 @@ def visit_input_values(
"""

def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefix, context=None, error=None):
value = input_values.get(input.name)
value = input_values.get(input.name, unset_value)
args = {
"input": input,
"parent": input_values,
@@ -182,13 +184,23 @@ def callback_helper(input, input_values, name_prefix, label_prefix, parent_prefix
input_values[input.name] = input.value

def get_current_case(input, input_values):
test_parameter = input.test_param
test_parameter_name = test_parameter.name
try:
return input.get_current_case(input_values[input.test_param.name])
if test_parameter_name not in input_values and allow_case_inference:
return input.get_current_case(test_parameter.get_initial_value(None, input_values))
else:
return input.get_current_case(input_values[test_parameter_name])
except (KeyError, ValueError):
return -1

context = ExpressionContext(input_values, context)
payload = {"context": context, "no_replacement_value": no_replacement_value}
payload = {
"context": context,
"no_replacement_value": no_replacement_value,
"allow_case_inference": allow_case_inference,
"unset_value": unset_value,
}
for input in inputs.values():
if isinstance(input, Repeat) or isinstance(input, UploadDataset):
values = input_values[input.name] = input_values.get(input.name, [])
@@ -411,16 +423,15 @@ def populate_state(
group_state = state[input.name]
if input.type == "repeat":
repeat_input = cast(Repeat, input)
if (
len(incoming[repeat_input.name]) > repeat_input.max
or len(incoming[repeat_input.name]) < repeat_input.min
repeat_name = repeat_input.name
repeat_incoming = incoming.get(repeat_name) or []
if repeat_incoming and (
len(repeat_incoming) > repeat_input.max or len(repeat_incoming) < repeat_input.min
):
errors[repeat_input.name] = (
"The number of repeat elements is outside the range specified by the tool."
)
errors[repeat_name] = "The number of repeat elements is outside the range specified by the tool."
else:
del group_state[:]
for rep in incoming[repeat_input.name]:
for rep in repeat_incoming:
new_state: ToolStateJobInstancePopulatedT = {}
group_state.append(new_state)
repeat_errors: ParameterValidationErrorsT = {}
@@ -454,10 +465,13 @@
current_case = conditional_input.get_current_case(value)
group_state = state[conditional_input.name] = {}
cast_errors: ParameterValidationErrorsT = {}
incoming_for_conditional = cast(
ToolStateJobInstanceT, incoming.get(conditional_input.name) or {}
)
populate_state(
request_context,
conditional_input.cases[current_case].inputs,
cast(ToolStateJobInstanceT, incoming.get(conditional_input.name)),
incoming_for_conditional,
group_state,
cast_errors,
context=context,
@@ -475,10 +489,11 @@
elif input.type == "section":
section_input = cast(Section, input)
section_errors: ParameterValidationErrorsT = {}
incoming_for_state = cast(ToolStateJobInstanceT, incoming.get(section_input.name) or {})
populate_state(
request_context,
section_input.inputs,
cast(ToolStateJobInstanceT, incoming.get(section_input.name)),
incoming_for_state,
group_state,
section_errors,
context=context,
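Two additions to `visit_input_values` carry most of the weight in this file: `unset_value` is a caller-supplied sentinel returned when a parameter is absent from the state (so absence can be told apart from an explicit `None`), and `allow_case_inference` lets a conditional fall back to its test parameter's initial value when that parameter is missing from the supplied state. A rough usage sketch, mirroring how the new `split_inputs_nested` in `meta.py` below drives it (`tool` and `nested_state` are placeholders):

    UNSET = object()  # sentinel distinct from None

    def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
        if value is UNSET:
            # parameter simply wasn't supplied; skip it instead of injecting a null
            return
        print(prefixed_name, value)

    visit_input_values(
        inputs=tool.inputs,          # placeholder: the tool's ordered input definitions
        input_values=nested_state,   # placeholder: a partially filled nested state dict
        callback=visitor,
        allow_case_inference=True,   # infer the conditional case from the test param's default
        unset_value=UNSET,
    )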
150 changes: 124 additions & 26 deletions lib/galaxy/tools/parameters/meta.py
@@ -19,10 +19,19 @@
matching,
subcollections,
)
from galaxy.util import permutations
from galaxy.util.permutations import (
build_combos,
input_classification,
is_in_state,
state_copy,
state_get_value,
state_remove_value,
state_set_value,
)
from . import visit_input_values
from .wrapped import process_key
from .._types import (
InputFormatT,
ToolRequestT,
ToolStateJobInstanceT,
)
@@ -161,7 +170,15 @@ def is_batch(value):
ExpandedT = Tuple[List[ToolStateJobInstanceT], Optional[matching.MatchingCollections]]


def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
def expand_flat_parameters_to_nested(incoming_copy: ToolRequestT) -> Dict[str, Any]:
nested_dict: Dict[str, Any] = {}
for incoming_key, incoming_value in incoming_copy.items():
if not incoming_key.startswith("__"):
process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)
return nested_dict


def expand_meta_parameters(trans, tool, incoming: ToolRequestT, input_format: InputFormatT) -> ExpandedT:
"""
Take in a dictionary of raw incoming parameters and expand to a list
of expanded incoming parameters (one set of parameters per tool
@@ -176,33 +193,24 @@ def expand_meta_parameters(trans, tool, incoming: ToolRequestT) -> ExpandedT:
# order matters, so the following reorders incoming
# according to tool.inputs (which is ordered).
incoming_copy = incoming.copy()
nested_dict: Dict[str, Any] = {}
for incoming_key, incoming_value in incoming_copy.items():
if not incoming_key.startswith("__"):
process_key(incoming_key, incoming_value=incoming_value, d=nested_dict)

reordered_incoming = {}

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if prefixed_name in incoming_copy:
reordered_incoming[prefixed_name] = incoming_copy[prefixed_name]
del incoming_copy[prefixed_name]
if input_format == "legacy":
nested_dict = expand_flat_parameters_to_nested(incoming_copy)
else:
nested_dict = incoming_copy

visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)
reordered_incoming.update(incoming_copy)
collections_to_match = matching.CollectionsToMatch()

def classifier(input_key):
value = incoming[input_key]
def classifier_from_value(value, input_key):
if isinstance(value, dict) and "values" in value:
# Explicit meta wrapper for inputs...
is_batch = value.get("batch", False)
is_linked = value.get("linked", True)
if is_batch and is_linked:
classification = permutations.input_classification.MATCHED
classification = input_classification.MATCHED
elif is_batch:
classification = permutations.input_classification.MULTIPLIED
classification = input_classification.MULTIPLIED
else:
classification = permutations.input_classification.SINGLE
classification = input_classification.SINGLE
if __collection_multirun_parameter(value):
collection_value = value["values"][0]
values = __expand_collection_parameter(
@@ -211,24 +219,114 @@ def classifier(input_key):
else:
values = value["values"]
else:
classification = permutations.input_classification.SINGLE
classification = input_classification.SINGLE
values = value
return classification, values

collections_to_match = matching.CollectionsToMatch()
nested = input_format != "legacy"
if not nested:
reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
incoming_template = reordered_incoming

def classifier_flat(input_key):
return classifier_from_value(incoming[input_key], input_key)

# Stick an unexpanded version of multirun keys so they can be replaced,
# by expand_mult_inputs.
incoming_template = reordered_incoming
single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_flat(
incoming_template, classifier_flat
)
else:
reordered_incoming = reorder_parameters(tool, incoming_copy, nested_dict, nested)
incoming_template = reordered_incoming
single_inputs, matched_multi_inputs, multiplied_multi_inputs = split_inputs_nested(
tool.inputs, incoming_template, classifier_from_value
)

expanded_incomings = permutations.expand_multi_inputs(incoming_template, classifier)
expanded_incomings = build_combos(single_inputs, matched_multi_inputs, multiplied_multi_inputs, nested=nested)
if collections_to_match.has_collections():
collection_info = trans.app.dataset_collection_manager.match_collections(collections_to_match)
else:
collection_info = None
return expanded_incomings, collection_info


def reorder_parameters(tool, incoming, nested_dict, nested):
# If we're going to multiply input dataset combinations
# order matters, so the following reorders incoming
# according to tool.inputs (which is ordered).
incoming_copy = state_copy(incoming, nested)

reordered_incoming = {}

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if is_in_state(incoming_copy, prefixed_name, nested):
value_to_copy_over = state_get_value(incoming_copy, prefixed_name, nested)
state_set_value(reordered_incoming, prefixed_name, value_to_copy_over, nested)
state_remove_value(incoming_copy, prefixed_name, nested)

visit_input_values(inputs=tool.inputs, input_values=nested_dict, callback=visitor)

def merge_into(from_object, into_object):
if isinstance(from_object, dict):
for key, value in from_object.items():
if key not in into_object:
into_object[key] = value
else:
into_target = into_object[key]
merge_into(value, into_target)
elif isinstance(from_object, list):
for index in range(len(from_object)):
if len(into_object) <= index:
into_object.append(from_object[index])
else:
merge_into(from_object[index], into_object[index])

merge_into(incoming_copy, reordered_incoming)
return reordered_incoming


def split_inputs_flat(inputs: Dict[str, Any], classifier):
single_inputs: Dict[str, Any] = {}
matched_multi_inputs: Dict[str, Any] = {}
multiplied_multi_inputs: Dict[str, Any] = {}

for input_key in inputs:
input_type, expanded_val = classifier(input_key)
if input_type == input_classification.SINGLE:
single_inputs[input_key] = expanded_val
elif input_type == input_classification.MATCHED:
matched_multi_inputs[input_key] = expanded_val
elif input_type == input_classification.MULTIPLIED:
multiplied_multi_inputs[input_key] = expanded_val

return (single_inputs, matched_multi_inputs, multiplied_multi_inputs)


def split_inputs_nested(inputs, nested_dict, classifier):
single_inputs: Dict[str, Any] = {}
matched_multi_inputs: Dict[str, Any] = {}
multiplied_multi_inputs: Dict[str, Any] = {}
unset_value = object()

def visitor(input, value, prefix, prefixed_name, prefixed_label, error, **kwargs):
if value is unset_value:
# don't want to inject extra nulls into state
return

input_type, expanded_val = classifier(value, prefixed_name)
if input_type == input_classification.SINGLE:
single_inputs[prefixed_name] = expanded_val
elif input_type == input_classification.MATCHED:
matched_multi_inputs[prefixed_name] = expanded_val
elif input_type == input_classification.MULTIPLIED:
multiplied_multi_inputs[prefixed_name] = expanded_val

visit_input_values(
inputs=inputs, input_values=nested_dict, callback=visitor, allow_case_inference=True, unset_value=unset_value
)
single_inputs_nested = expand_flat_parameters_to_nested(single_inputs)
return (single_inputs_nested, matched_multi_inputs, multiplied_multi_inputs)


def __expand_collection_parameter(trans, input_key, incoming_val, collections_to_match, linked=False):
# If subcollection multirun of data_collection param - value will
# be "hdca_id|subcollection_type" else it will just be hdca_id
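The classification above decides how each value multiplies into jobs: SINGLE values are reused as-is, MATCHED (batch and linked) lists are zipped element-wise, and MULTIPLIED (batch, unlinked) lists are crossed against everything else before `build_combos` assembles the final job states. A toy, pure-Python illustration of that combinatorics (not the Galaxy API itself; values are made up):

    from itertools import product

    single = {"threshold": 5}
    matched = {"input1": ["a1", "a2"], "input2": ["b1", "b2"]}   # linked batches: zipped
    multiplied = {"reference": ["r1", "r2"]}                     # unlinked batch: cross product

    jobs = []
    for zipped in zip(*matched.values()):
        matched_state = dict(zip(matched.keys(), zipped))
        for crossed in product(*multiplied.values()):
            crossed_state = dict(zip(multiplied.keys(), crossed))
            jobs.append({**single, **matched_state, **crossed_state})

    assert len(jobs) == 2 * 2  # two zipped pairs times two references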
9 changes: 6 additions & 3 deletions lib/galaxy/tools/parameters/wrapped.py
@@ -25,6 +25,10 @@
InputValueWrapper,
SelectToolParameterWrapper,
)
from galaxy.util.permutations import (
looks_like_flattened_repeat_key,
split_flattened_repeat_key,
)

PARAMS_UNWRAPPED = object()

@@ -172,10 +176,9 @@ def process_key(incoming_key: str, incoming_value: Any, d: Dict[str, Any]):
# In case we get an empty repeat after we already filled in a repeat element
return
d[incoming_key] = incoming_value
elif key_parts[0].rsplit("_", 1)[-1].isdigit():
elif looks_like_flattened_repeat_key(key_parts[0]):
# Repeat
input_name, _index = key_parts[0].rsplit("_", 1)
index = int(_index)
input_name, index = split_flattened_repeat_key(key_parts[0])
d.setdefault(input_name, [])
newlist: List[Dict[Any, Any]] = [{} for _ in range(index + 1)]
d[input_name].extend(newlist[len(d[input_name]) :])
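`process_key` now delegates the `name_<index>` detection to shared helpers instead of re-implementing the `rsplit`/`isdigit` check inline. Approximate behavior of those helpers, reconstructed from the code they replace (the real implementations live in `galaxy.util.permutations`):

    # Approximation for illustration; not the galaxy.util.permutations source.
    def looks_like_flattened_repeat_key(key: str) -> bool:
        # "queries_0" -> True, "queries" -> False
        return key.rsplit("_", 1)[-1].isdigit()

    def split_flattened_repeat_key(key: str):
        # "queries_0" -> ("queries", 0)
        name, index = key.rsplit("_", 1)
        return name, int(index)

    assert looks_like_flattened_repeat_key("queries_0")
    assert split_flattened_repeat_key("queries_0") == ("queries", 0)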
2 changes: 2 additions & 0 deletions lib/galaxy/tools/wrappers.py
@@ -802,6 +802,8 @@ def __init__(self, input_datasets: Optional[Dict[str, Any]] = None) -> None:
self.identifier_key_dict = {}

def identifier(self, dataset_value: str, input_values: Dict[str, str]) -> Optional[str]:
if isinstance(dataset_value, list):
raise TypeError(f"Expected {dataset_value} to be hashable")
element_identifier = None
if identifier_key := self.identifier_key_dict.get(dataset_value, None):
element_identifier = input_values.get(identifier_key, None)
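The new guard in `identifier()` exists because `dataset_value` is used as a dictionary key on the next line; a list is unhashable, so without the check the failure surfaces as a cryptic `TypeError` from the lookup rather than a message naming the offending value. A minimal reproduction of the underlying failure mode:

    identifier_key_dict = {}
    try:
        identifier_key_dict.get(["not", "hashable"])  # lists cannot be dict keys
    except TypeError as err:
        print(err)  # unhashable type: 'list'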