diff --git a/CHANGELOG.md b/CHANGELOG.md index 3afddf3ed5..6a186bf7c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added - Docstrings and doctestable examples to `record.py`. +- Inputs can be validated using operations + - `validate` parameter in `Input` takes `Operation.instance_name` ### Fixed - New model tutorial mentions file paths that should be edited. diff --git a/dffml/df/exceptions.py b/dffml/df/exceptions.py index 73764205e4..f3d364e8ef 100644 --- a/dffml/df/exceptions.py +++ b/dffml/df/exceptions.py @@ -16,3 +16,7 @@ class NotOpImp(Exception): class InputValidationError(Exception): pass + + +class ValidatorMissing(Exception): + pass diff --git a/dffml/df/memory.py b/dffml/df/memory.py index d839796319..43cd8f1329 100644 --- a/dffml/df/memory.py +++ b/dffml/df/memory.py @@ -20,7 +20,11 @@ Set, ) -from .exceptions import ContextNotPresent, DefinitionNotInContext +from .exceptions import ( + ContextNotPresent, + DefinitionNotInContext, + ValidatorMissing, +) from .types import Input, Parameter, Definition, Operation, Stage, DataFlow from .base import ( OperationException, @@ -122,6 +126,26 @@ async def inputs(self) -> AsyncIterator[Input]: for item in self.__inputs: yield item + def remove_input(self, item: Input): + for x in self.__inputs[:]: + if x.uid == item.uid: + self.__inputs.remove(x) + break + + def remove_unvalidated_inputs(self) -> "MemoryInputSet": + """ + Removes `unvalidated` inputs from internal list and returns the same. 
+ """ + unvalidated_inputs = [] + for x in self.__inputs[:]: + if not x.validated: + unvalidated_inputs.append(x) + self.__inputs.remove(x) + unvalidated_input_set = MemoryInputSet( + MemoryInputSetConfig(ctx=self.ctx, inputs=unvalidated_inputs) + ) + return unvalidated_input_set + class MemoryParameterSetConfig(NamedTuple): ctx: BaseInputSetContext @@ -249,15 +273,19 @@ async def add(self, input_set: BaseInputSet): handle_string = handle.as_string() # TODO These ctx.add calls should probably happen after inputs are in # self.ctxhd + + # remove unvalidated inputs + unvalidated_input_set = input_set.remove_unvalidated_inputs() + # If the context for this input set does not exist create a # NotificationSet for it to notify the orchestrator if not handle_string in self.input_notification_set: self.input_notification_set[handle_string] = NotificationSet() async with self.ctx_notification_set() as ctx: - await ctx.add(input_set.ctx) + await ctx.add((None, input_set.ctx)) # Add the input set to the incoming inputs async with self.input_notification_set[handle_string]() as ctx: - await ctx.add(input_set) + await ctx.add((unvalidated_input_set, input_set)) # Associate inputs with their context handle grouped by definition async with self.ctxhd_lock: # Create dict for handle_string if not present @@ -921,6 +949,7 @@ async def run_dispatch( octx: BaseOrchestratorContext, operation: Operation, parameter_set: BaseParameterSet, + set_valid: bool = True, ): """ Run an operation in the background and add its outputs to the input @@ -952,14 +981,14 @@ async def run_dispatch( if not key in expand: output = [output] for value in output: - inputs.append( - Input( - value=value, - definition=operation.outputs[key], - parents=parents, - origin=(operation.instance_name, key), - ) + new_input = Input( + value=value, + definition=operation.outputs[key], + parents=parents, + origin=(operation.instance_name, key), ) + new_input.validated = set_valid + inputs.append(new_input) except 
KeyError as error:
             raise KeyError(
                 "Value %s missing from output:definition mapping %s(%s)"
@@ -1020,6 +1049,38 @@ async def operations_parameter_set_pairs(
         ):
             yield operation, parameter_set
 
+    async def validator_target_set_pairs(
+        self,
+        octx: BaseOperationNetworkContext,
+        rctx: BaseRedundancyCheckerContext,
+        ctx: BaseInputSetContext,
+        dataflow: DataFlow,
+        unvalidated_input_set: BaseInputSet,
+    ):
+        async for unvalidated_input in unvalidated_input_set.inputs():
+            validator_instance_name = unvalidated_input.definition.validate
+            validator = dataflow.validators.get(validator_instance_name, None)
+            if validator is None:
+                raise ValidatorMissing(
+                    f"Validator with instance_name {validator_instance_name} not found"
+                )
+            # There is only one `input` in `validators`
+            input_name, input_definition = list(validator.inputs.items())[0]
+            parameter = Parameter(
+                key=input_name,
+                value=unvalidated_input.value,
+                origin=unvalidated_input,
+                definition=input_definition,
+            )
+            parameter_set = MemoryParameterSet(
+                MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
+            )
+            async for parameter_set, exists in rctx.exists(
+                validator, parameter_set
+            ):
+                if not exists:
+                    yield validator, parameter_set
+
 
 @entrypoint("memory")
 class MemoryOperationImplementationNetwork(
@@ -1382,17 +1443,44 @@ async def run_operations_for_ctx(
                             task.print_stack(file=output)
                             self.logger.error("%s", output.getvalue().rstrip())
                             output.close()
+
                     elif task is input_set_enters_network:
                         (
                             more,
                             new_input_sets,
                         ) = input_set_enters_network.result()
-                        for new_input_set in new_input_sets:
+                        for (
+                            unvalidated_input_set,
+                            new_input_set,
+                        ) in new_input_sets:
+                            async for operation, parameter_set in self.nctx.validator_target_set_pairs(
+                                self.octx,
+                                self.rctx,
+                                ctx,
+                                self.config.dataflow,
+                                unvalidated_input_set,
+                            ):
+                                await self.rctx.add(
+                                    operation, parameter_set
+                                )  # is this required here?
+ dispatch_operation = await self.nctx.dispatch( + self, operation, parameter_set + ) + dispatch_operation.operation = operation + dispatch_operation.parameter_set = ( + parameter_set + ) + tasks.add(dispatch_operation) + self.logger.debug( + "[%s]: dispatch operation: %s", + ctx_str, + operation.instance_name, + ) # forward inputs to subflow await self.forward_inputs_to_subflow( [x async for x in new_input_set.inputs()] ) - # Identify which operations have complete contextually + # Identify which operations have completed contextually # appropriate input sets which haven't been run yet async for operation, parameter_set in self.nctx.operations_parameter_set_pairs( self.ictx, @@ -1402,6 +1490,9 @@ async def run_operations_for_ctx( self.config.dataflow, new_input_set=new_input_set, ): + # Validation operations shouldn't be run here + if operation.validator: + continue # Add inputs and operation to redundancy checker before # dispatch await self.rctx.add(operation, parameter_set) diff --git a/dffml/df/types.py b/dffml/df/types.py index 1a09aaa436..faa802608e 100644 --- a/dffml/df/types.py +++ b/dffml/df/types.py @@ -122,6 +122,7 @@ class Operation(NamedTuple, Entrypoint): conditions: Optional[List[Definition]] = [] expand: Optional[List[str]] = [] instance_name: Optional[str] = None + validator: bool = False def export(self): exported = { @@ -270,11 +271,13 @@ def __init__( definition: Definition, parents: Optional[List["Input"]] = None, origin: Optional[Union[str, Tuple[Operation, str]]] = "seed", + validated: bool = True, *, uid: Optional[str] = "", ): # TODO Add optional parameter Input.target which specifies the operation # instance name this Input is intended for. 
+ self.validated = validated if parents is None: parents = [] if definition.spec is not None: @@ -288,7 +291,11 @@ def __init__( elif isinstance(value, dict): value = definition.spec(**value) if definition.validate is not None: - value = definition.validate(value) + if callable(definition.validate): + value = definition.validate(value) + # if validate is a string (operation.instance_name) set `not validated` + elif isinstance(definition.validate, str): + self.validated = False self.value = value self.definition = definition self.parents = parents @@ -424,6 +431,8 @@ def __post_init__(self): self.by_origin = {} if self.implementations is None: self.implementations = {} + self.validators = {} # Maps `validator` ops instance_name to op + # Allow callers to pass in functions decorated with op. Iterate over the # given operations and replace any which have been decorated with their # operation. Add the implementation to our dict of implementations. @@ -451,9 +460,10 @@ def __post_init__(self): self.operations[instance_name] = operation value = operation # Make sure every operation has the correct instance name - self.operations[instance_name] = value._replace( - instance_name=instance_name - ) + value = value._replace(instance_name=instance_name) + self.operations[instance_name] = value + if value.validator: + self.validators[instance_name] = value # Grab all definitions from operations operations = list(self.operations.values()) definitions = list( diff --git a/examples/shouldi/tests/test_npm_audit.py b/examples/shouldi/tests/test_npm_audit.py index 6214e68c60..06af9726bf 100644 --- a/examples/shouldi/tests/test_npm_audit.py +++ b/examples/shouldi/tests/test_npm_audit.py @@ -21,7 +21,7 @@ class TestRunNPM_AuditOp(AsyncTestCase): "36b3ce51780ee6ea8dcec266c9d09e3a00198868ba1b041569950b82cf45884da0c47ec354dd8514022169849dfe8b7c", ) async def test_run(self, npm_audit, javascript_algo): - with prepend_to_path(npm_audit / "bin",): + with prepend_to_path(npm_audit / "bin"): 
results = await run_npm_audit( str( javascript_algo diff --git a/model/scikit/dffml_model_scikit/scikit_models.py b/model/scikit/dffml_model_scikit/scikit_models.py index 57490a61c2..4488e4a1a0 100644 --- a/model/scikit/dffml_model_scikit/scikit_models.py +++ b/model/scikit/dffml_model_scikit/scikit_models.py @@ -226,7 +226,7 @@ def applicable_features(self, features): field( "Directory where state should be saved", default=pathlib.Path( - "~", ".cache", "dffml", f"scikit-{entry_point_name}", + "~", ".cache", "dffml", f"scikit-{entry_point_name}" ), ), ), diff --git a/scripts/docs.py b/scripts/docs.py index 7dc7e93d31..afada08655 100644 --- a/scripts/docs.py +++ b/scripts/docs.py @@ -212,7 +212,7 @@ def gen_docs( def fake_getpwuid(uid): return pwd.struct_passwd( - ("user", "x", uid, uid, "", "/home/user", "/bin/bash",) + ("user", "x", uid, uid, "", "/home/user", "/bin/bash") ) diff --git a/tests/test_types.py b/tests/test_input_validation.py similarity index 62% rename from tests/test_types.py rename to tests/test_input_validation.py index 1cbec35696..093bc3726d 100644 --- a/tests/test_types.py +++ b/tests/test_input_validation.py @@ -19,6 +19,10 @@ def pie_validation(x): ShapeName = Definition( name="shape_name", primitive="str", validate=lambda x: x.upper() ) +SHOUTIN = Definition( + name="shout_in", primitive="str", validate="validate_shout_instance" +) +SHOUTOUT = Definition(name="shout_out", primitive="str") @op( @@ -35,6 +39,20 @@ async def get_circle(name: str, radius: float, pie: float): } +@op( + inputs={"shout_in": SHOUTIN}, + outputs={"shout_in_validated": SHOUTIN}, + validator=True, +) +def validate_shouts(shout_in): + return {"shout_in_validated": shout_in + "_validated"} + + +@op(inputs={"shout_in": SHOUTIN}, outputs={"shout_out": SHOUTOUT}) +def echo_shout(shout_in): + return {"shout_out": shout_in} + + class TestDefintion(AsyncTestCase): async def setUp(self): self.dataflow = DataFlow( @@ -80,3 +98,34 @@ async def test_validation_error(self): ] } 
pass
+
+    async def test_validation_by_op(self):
+        test_dataflow = DataFlow(
+            operations={
+                "validate_shout_instance": validate_shouts.op,
+                "echo_shout": echo_shout.op,
+                "get_single": GetSingle.imp.op,
+            },
+            seed=[
+                Input(
+                    value=[echo_shout.op.outputs["shout_out"].name],
+                    definition=GetSingle.op.inputs["spec"],
+                )
+            ],
+            implementations={
+                validate_shouts.op.name: validate_shouts.imp,
+                echo_shout.op.name: echo_shout.imp,
+            },
+        )
+        test_inputs = {
+            "TestShoutOut": [
+                Input(value="validation_status:", definition=SHOUTIN)
+            ]
+        }
+        async with MemoryOrchestrator.withconfig({}) as orchestrator:
+            async with orchestrator(test_dataflow) as octx:
+                async for ctx_str, results in octx.run(test_inputs):
+                    self.assertIn("shout_out", results)
+                    self.assertEqual(
+                        results["shout_out"], "validation_status:_validated"
+                    )