Skip to content

Commit

Permalink
Merge pull request #19027 from jmchilton/parameter_model_enhancements
Browse files Browse the repository at this point in the history
A variety of improvements around tool parameter modeling (second try)
  • Loading branch information
bgruening authored Oct 29, 2024
2 parents 29cd291 + e053eee commit d11d893
Show file tree
Hide file tree
Showing 53 changed files with 2,931 additions and 692 deletions.
16 changes: 16 additions & 0 deletions lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
encode,
encode_test,
fill_static_defaults,
landing_decode,
landing_encode,
)
from .factory import (
from_input_source,
Expand Down Expand Up @@ -39,6 +41,7 @@
HiddenParameterModel,
IntegerParameterModel,
LabelValue,
RawStateDict,
RepeatParameterModel,
RulesParameterModel,
SelectParameterModel,
Expand All @@ -49,15 +52,20 @@
ToolParameterT,
validate_against_model,
validate_internal_job,
validate_internal_landing_request,
validate_internal_request,
validate_internal_request_dereferenced,
validate_landing_request,
validate_request,
validate_test_case,
validate_workflow_step,
validate_workflow_step_linked,
ValidationFunctionT,
)
from .state import (
JobInternalToolState,
LandingRequestInternalToolState,
LandingRequestToolState,
RequestInternalDereferencedToolState,
RequestInternalToolState,
RequestToolState,
Expand Down Expand Up @@ -113,10 +121,14 @@
"ConditionalParameterModel",
"ConditionalWhen",
"RepeatParameterModel",
"RawStateDict",
"ValidationFunctionT",
"validate_against_model",
"validate_internal_job",
"validate_internal_landing_request",
"validate_internal_request",
"validate_internal_request_dereferenced",
"validate_landing_request",
"validate_request",
"validate_test_case",
"validate_workflow_step",
Expand All @@ -130,6 +142,8 @@
"RequestToolState",
"RequestInternalToolState",
"RequestInternalDereferencedToolState",
"LandingRequestToolState",
"LandingRequestInternalToolState",
"flat_state_path",
"keys_starting_with",
"visit_input_values",
Expand All @@ -139,6 +153,8 @@
"encode",
"encode_test",
"fill_static_defaults",
"landing_decode",
"landing_encode",
"dereference",
"WorkflowStepToolState",
"WorkflowStepLinkedToolState",
Expand Down
11 changes: 11 additions & 0 deletions lib/galaxy/tool_util/parameters/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from typing import (
Any,
cast,
List,
Optional,
Expand All @@ -15,6 +16,7 @@

# https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional
from typing_extensions import (
Annotated,
get_args,
get_origin,
)
Expand Down Expand Up @@ -46,3 +48,12 @@ def cast_as_type(arg) -> Type:

def is_optional(field) -> bool:
return get_origin(field) is Union and type(None) in get_args(field)


def expand_annotation(field: Type, new_annotations: List[Any]) -> Type:
is_annotation = get_origin(field) is Annotated
if is_annotation:
args = get_args(field) # noqa: F841
return Annotated[tuple([args[0], *args[1:], *new_annotations])] # type: ignore[return-value]
else:
return Annotated[tuple([field, *new_annotations])] # type: ignore[return-value]
173 changes: 110 additions & 63 deletions lib/galaxy/tool_util/parameters/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,15 @@
)
from .state import (
JobInternalToolState,
LandingRequestInternalToolState,
LandingRequestToolState,
RequestInternalDereferencedToolState,
RequestInternalToolState,
RequestToolState,
TestCaseToolState,
)
from .visitor import (
Callback,
validate_explicit_conditional_test_value,
visit_input_values,
VISITOR_NO_REPLACEMENT,
Expand All @@ -54,40 +57,22 @@
log = logging.getLogger(__name__)


DecodeFunctionT = Callable[[str], int]
EncodeFunctionT = Callable[[int], str]
DereferenceCallable = Callable[[DataRequestUri], DataRequestInternalHda]
# interfaces for adapting test data dictionaries to tool request dictionaries
# e.g. {class: File, path: foo.bed} => {src: hda, id: ab1235cdfea3}
AdaptDatasets = Callable[[JsonTestDatasetDefDict], DataRequestHda]
AdaptCollections = Callable[[JsonTestCollectionDefDict], DataCollectionRequest]


def decode(
external_state: RequestToolState, input_models: ToolParameterBundle, decode_id: Callable[[str], int]
) -> RequestInternalToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""
"""Prepare an internal representation of tool state (request_internal) for storing in the database."""

external_state.validate(input_models)

def decode_src_dict(src_dict: dict):
if "id" in src_dict:
decoded_dict = src_dict.copy()
decoded_dict["id"] = decode_id(src_dict["id"])
return decoded_dict
else:
return src_dict

def decode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
if value is None:
return VISITOR_NO_REPLACEMENT
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(decode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
if value is None:
return VISITOR_NO_REPLACEMENT
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

decode_callback = _decode_callback_for(decode_id)
internal_state_dict = visit_input_values(
input_models,
external_state,
Expand All @@ -100,44 +85,53 @@ def decode_callback(parameter: ToolParameterT, value: Any):


def encode(
external_state: RequestInternalToolState, input_models: ToolParameterBundle, encode_id: Callable[[int], str]
internal_state: RequestInternalToolState, input_models: ToolParameterBundle, encode_id: EncodeFunctionT
) -> RequestToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

def encode_src_dict(src_dict: dict):
if "id" in src_dict:
encoded_dict = src_dict.copy()
encoded_dict["id"] = encode_id(src_dict["id"])
return encoded_dict
else:
return src_dict

def encode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(encode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT
"""Prepare an external representation of tool state (request) from persisted state in the database (request_internal)."""

encode_callback = _encode_callback_for(encode_id)
request_state_dict = visit_input_values(
input_models,
external_state,
internal_state,
encode_callback,
)
request_state = RequestToolState(request_state_dict)
request_state.validate(input_models)
return request_state


DereferenceCallable = Callable[[DataRequestUri], DataRequestInternalHda]
def landing_decode(
external_state: LandingRequestToolState, input_models: ToolParameterBundle, decode_id: Callable[[str], int]
) -> LandingRequestInternalToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

external_state.validate(input_models)
decode_callback = _decode_callback_for(decode_id)
internal_state_dict = visit_input_values(
input_models,
external_state,
decode_callback,
)

internal_request_state = LandingRequestInternalToolState(internal_state_dict)
internal_request_state.validate(input_models)
return internal_request_state


def landing_encode(
internal_state: LandingRequestInternalToolState, input_models: ToolParameterBundle, encode_id: EncodeFunctionT
) -> LandingRequestToolState:
"""Prepare an external representation of tool state (request) for storing in the database (request_internal)."""

encode_callback = _encode_callback_for(encode_id)
request_state_dict = visit_input_values(
input_models,
internal_state,
encode_callback,
)
request_state = LandingRequestToolState(request_state_dict)
request_state.validate(input_models)
return request_state


def dereference(
Expand Down Expand Up @@ -177,12 +171,6 @@ def dereference_callback(parameter: ToolParameterT, value: Any):
return request_state


# interfaces for adapting test data dictionaries to tool request dictionaries
# e.g. {class: File, path: foo.bed} => {src: hda, id: ab1235cdfea3}
AdaptDatasets = Callable[[JsonTestDatasetDefDict], DataRequestHda]
AdaptCollections = Callable[[JsonTestCollectionDefDict], DataCollectionRequest]


def encode_test(
test_case_state: TestCaseToolState,
input_models: ToolParameterBundle,
Expand Down Expand Up @@ -324,7 +312,6 @@ def _fill_default_for(tool_state: Dict[str, Any], parameter: ToolParameterT) ->
)
test_value = validate_explicit_conditional_test_value(test_parameter_name, explicit_test_value)
when = _select_which_when(conditional, test_value, conditional_state)
test_parameter = conditional.test_parameter
_fill_default_for(conditional_state, test_parameter)
_fill_defaults(conditional_state, when)
elif parameter_type in ["gx_repeat"]:
Expand Down Expand Up @@ -358,3 +345,63 @@ def _select_which_when(
raise Exception(
f"Invalid conditional test value ({test_value}) for parameter ({conditional.test_parameter.name})"
)


def _encode_callback_for(encode_id: EncodeFunctionT) -> Callback:

def encode_src_dict(src_dict: dict):
if "id" in src_dict:
encoded_dict = src_dict.copy()
encoded_dict["id"] = encode_id(src_dict["id"])
return encoded_dict
else:
return src_dict

def encode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(encode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
assert isinstance(value, dict), str(value)
return encode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

return encode_callback


def _decode_callback_for(decode_id: DecodeFunctionT) -> Callback:

def decode_src_dict(src_dict: dict):
if "id" in src_dict:
decoded_dict = src_dict.copy()
decoded_dict["id"] = decode_id(src_dict["id"])
return decoded_dict
else:
return src_dict

def decode_callback(parameter: ToolParameterT, value: Any):
if parameter.parameter_type == "gx_data":
if value is None:
return VISITOR_NO_REPLACEMENT
data_parameter = cast(DataParameterModel, parameter)
if data_parameter.multiple:
assert isinstance(value, list), str(value)
return list(map(decode_src_dict, value))
else:
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
elif parameter.parameter_type == "gx_data_collection":
if value is None:
return VISITOR_NO_REPLACEMENT
assert isinstance(value, dict), str(value)
return decode_src_dict(value)
else:
return VISITOR_NO_REPLACEMENT

return decode_callback
Loading

0 comments on commit d11d893

Please sign in to comment.