Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
df: memory: Input validation using operations
Browse files Browse the repository at this point in the history
Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
aghinsa authored Mar 14, 2020
1 parent 6a32f39 commit f87b180
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Docstrings and doctestable examples to `record.py`.
- Inputs can be validated using operations
- `validate` parameter in `Input` takes `Operation.instance_name`
### Fixed
- New model tutorial mentions file paths that should be edited.

Expand Down
4 changes: 4 additions & 0 deletions dffml/df/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class NotOpImp(Exception):

class InputValidationError(Exception):
pass


class ValidatorMissing(Exception):
pass
115 changes: 103 additions & 12 deletions dffml/df/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
Set,
)

from .exceptions import ContextNotPresent, DefinitionNotInContext
from .exceptions import (
ContextNotPresent,
DefinitionNotInContext,
ValidatorMissing,
)
from .types import Input, Parameter, Definition, Operation, Stage, DataFlow
from .base import (
OperationException,
Expand Down Expand Up @@ -122,6 +126,26 @@ async def inputs(self) -> AsyncIterator[Input]:
for item in self.__inputs:
yield item

def remove_input(self, item: Input):
    """Remove the first stored input whose UID matches ``item``.

    Inputs are matched by ``uid``, not object identity; at most one
    entry is removed. No-op when no stored input has a matching UID.
    """
    match = next(
        (existing for existing in self.__inputs if existing.uid == item.uid),
        None,
    )
    if match is not None:
        self.__inputs.remove(match)

def remove_unvalidated_inputs(self) -> "MemoryInputSet":
    """Strip not-yet-validated inputs out of this set.

    Any input whose ``validated`` flag is False is removed from the
    internal list. The removed inputs are returned wrapped in a new
    ``MemoryInputSet`` that shares this set's context, so callers can
    route them through their validator operations.
    """
    pending = [candidate for candidate in self.__inputs if not candidate.validated]
    for candidate in pending:
        self.__inputs.remove(candidate)
    return MemoryInputSet(
        MemoryInputSetConfig(ctx=self.ctx, inputs=pending)
    )


class MemoryParameterSetConfig(NamedTuple):
ctx: BaseInputSetContext
Expand Down Expand Up @@ -249,15 +273,19 @@ async def add(self, input_set: BaseInputSet):
handle_string = handle.as_string()
# TODO These ctx.add calls should probably happen after inputs are in
# self.ctxhd

# remove unvalidated inputs
unvalidated_input_set = input_set.remove_unvalidated_inputs()

# If the context for this input set does not exist create a
# NotificationSet for it to notify the orchestrator
if not handle_string in self.input_notification_set:
self.input_notification_set[handle_string] = NotificationSet()
async with self.ctx_notification_set() as ctx:
await ctx.add(input_set.ctx)
await ctx.add((None, input_set.ctx))
# Add the input set to the incoming inputs
async with self.input_notification_set[handle_string]() as ctx:
await ctx.add(input_set)
await ctx.add((unvalidated_input_set, input_set))
# Associate inputs with their context handle grouped by definition
async with self.ctxhd_lock:
# Create dict for handle_string if not present
Expand Down Expand Up @@ -921,6 +949,7 @@ async def run_dispatch(
octx: BaseOrchestratorContext,
operation: Operation,
parameter_set: BaseParameterSet,
set_valid: bool = True,
):
"""
Run an operation in the background and add its outputs to the input
Expand Down Expand Up @@ -952,14 +981,14 @@ async def run_dispatch(
if not key in expand:
output = [output]
for value in output:
inputs.append(
Input(
value=value,
definition=operation.outputs[key],
parents=parents,
origin=(operation.instance_name, key),
)
new_input = Input(
value=value,
definition=operation.outputs[key],
parents=parents,
origin=(operation.instance_name, key),
)
new_input.validated = set_valid
inputs.append(new_input)
except KeyError as error:
raise KeyError(
"Value %s missing from output:definition mapping %s(%s)"
Expand Down Expand Up @@ -1020,6 +1049,38 @@ async def operations_parameter_set_pairs(
):
yield operation, parameter_set

async def validator_target_set_pairs(
    self,
    octx: BaseOperationNetworkContext,
    rctx: BaseRedundancyCheckerContext,
    ctx: BaseInputSetContext,
    dataflow: DataFlow,
    unvalidated_input_set: BaseInputSet,
):
    """
    Pair validator operations with the parameter sets needed to run them.

    For each input awaiting validation, look up the validator operation
    named by the input definition's ``validate`` field (an
    ``Operation.instance_name`` string) and yield
    ``(validator_operation, parameter_set)`` pairs that the redundancy
    checker has not already seen.

    Raises
    ------
    ValidatorMissing
        If the dataflow has no validator registered under the requested
        instance name.
    """
    async for unvalidated_input in unvalidated_input_set.inputs():
        validator_instance_name = unvalidated_input.definition.validate
        validator = dataflow.validators.get(validator_instance_name, None)
        if validator is None:
            # Bug fix: the message was a plain string literal, so the
            # {validator_instance_name} placeholder was never
            # interpolated. It must be an f-string.
            raise ValidatorMissing(
                f"Validator with instance_name {validator_instance_name} not found"
            )
        # Validator operations declare exactly one input; grab its
        # name and definition to build the parameter.
        input_name, input_definition = list(validator.inputs.items())[0]
        parameter = Parameter(
            key=input_name,
            value=unvalidated_input.value,
            origin=unvalidated_input,
            definition=input_definition,
        )
        parameter_set = MemoryParameterSet(
            MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
        )
        # Skip pairs the redundancy checker says have already run.
        async for parameter_set, exists in rctx.exists(
            validator, parameter_set
        ):
            if not exists:
                yield validator, parameter_set


@entrypoint("memory")
class MemoryOperationImplementationNetwork(
Expand Down Expand Up @@ -1382,17 +1443,44 @@ async def run_operations_for_ctx(
task.print_stack(file=output)
self.logger.error("%s", output.getvalue().rstrip())
output.close()

elif task is input_set_enters_network:
(
more,
new_input_sets,
) = input_set_enters_network.result()
for new_input_set in new_input_sets:
for (
unvalidated_input_set,
new_input_set,
) in new_input_sets:
async for operation, parameter_set in self.nctx.validator_target_set_pairs(
self.octx,
self.rctx,
ctx,
self.config.dataflow,
unvalidated_input_set,
):
await self.rctx.add(
operation, parameter_set
) # is this required here?
dispatch_operation = await self.nctx.dispatch(
self, operation, parameter_set
)
dispatch_operation.operation = operation
dispatch_operation.parameter_set = (
parameter_set
)
tasks.add(dispatch_operation)
self.logger.debug(
"[%s]: dispatch operation: %s",
ctx_str,
operation.instance_name,
)
# forward inputs to subflow
await self.forward_inputs_to_subflow(
[x async for x in new_input_set.inputs()]
)
# Identify which operations have complete contextually
# Identify which operations have completed contextually
# appropriate input sets which haven't been run yet
async for operation, parameter_set in self.nctx.operations_parameter_set_pairs(
self.ictx,
Expand All @@ -1402,6 +1490,9 @@ async def run_operations_for_ctx(
self.config.dataflow,
new_input_set=new_input_set,
):
# Validation operations shouldn't be run here
if operation.validator:
continue
# Add inputs and operation to redundancy checker before
# dispatch
await self.rctx.add(operation, parameter_set)
Expand Down
18 changes: 14 additions & 4 deletions dffml/df/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class Operation(NamedTuple, Entrypoint):
conditions: Optional[List[Definition]] = []
expand: Optional[List[str]] = []
instance_name: Optional[str] = None
validator: bool = False

def export(self):
exported = {
Expand Down Expand Up @@ -270,11 +271,13 @@ def __init__(
definition: Definition,
parents: Optional[List["Input"]] = None,
origin: Optional[Union[str, Tuple[Operation, str]]] = "seed",
validated: bool = True,
*,
uid: Optional[str] = "",
):
# TODO Add optional parameter Input.target which specifies the operation
# instance name this Input is intended for.
self.validated = validated
if parents is None:
parents = []
if definition.spec is not None:
Expand All @@ -288,7 +291,11 @@ def __init__(
elif isinstance(value, dict):
value = definition.spec(**value)
if definition.validate is not None:
value = definition.validate(value)
if callable(definition.validate):
value = definition.validate(value)
# if validate is a string (operation.instance_name) set `not validated`
elif isinstance(definition.validate, str):
self.validated = False
self.value = value
self.definition = definition
self.parents = parents
Expand Down Expand Up @@ -424,6 +431,8 @@ def __post_init__(self):
self.by_origin = {}
if self.implementations is None:
self.implementations = {}
self.validators = {} # Maps `validator` ops instance_name to op

# Allow callers to pass in functions decorated with op. Iterate over the
# given operations and replace any which have been decorated with their
# operation. Add the implementation to our dict of implementations.
Expand Down Expand Up @@ -451,9 +460,10 @@ def __post_init__(self):
self.operations[instance_name] = operation
value = operation
# Make sure every operation has the correct instance name
self.operations[instance_name] = value._replace(
instance_name=instance_name
)
value = value._replace(instance_name=instance_name)
self.operations[instance_name] = value
if value.validator:
self.validators[instance_name] = value
# Grab all definitions from operations
operations = list(self.operations.values())
definitions = list(
Expand Down
2 changes: 1 addition & 1 deletion examples/shouldi/tests/test_npm_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TestRunNPM_AuditOp(AsyncTestCase):
"36b3ce51780ee6ea8dcec266c9d09e3a00198868ba1b041569950b82cf45884da0c47ec354dd8514022169849dfe8b7c",
)
async def test_run(self, npm_audit, javascript_algo):
with prepend_to_path(npm_audit / "bin",):
with prepend_to_path(npm_audit / "bin"):
results = await run_npm_audit(
str(
javascript_algo
Expand Down
2 changes: 1 addition & 1 deletion model/scikit/dffml_model_scikit/scikit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def applicable_features(self, features):
field(
"Directory where state should be saved",
default=pathlib.Path(
"~", ".cache", "dffml", f"scikit-{entry_point_name}",
"~", ".cache", "dffml", f"scikit-{entry_point_name}"
),
),
),
Expand Down
2 changes: 1 addition & 1 deletion scripts/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def gen_docs(

def fake_getpwuid(uid):
return pwd.struct_passwd(
("user", "x", uid, uid, "", "/home/user", "/bin/bash",)
("user", "x", uid, uid, "", "/home/user", "/bin/bash")
)


Expand Down
49 changes: 49 additions & 0 deletions tests/test_types.py → tests/test_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def pie_validation(x):
ShapeName = Definition(
name="shape_name", primitive="str", validate=lambda x: x.upper()
)
SHOUTIN = Definition(
name="shout_in", primitive="str", validate="validate_shout_instance"
)
SHOUTOUT = Definition(name="shout_out", primitive="str")


@op(
Expand All @@ -35,6 +39,20 @@ async def get_circle(name: str, radius: float, pie: float):
}


@op(
    inputs={"shout_in": SHOUTIN},
    outputs={"shout_in_validated": SHOUTIN},
    validator=True,
)
def validate_shouts(shout_in):
    """Validator op: tag the incoming shout with a ``_validated`` suffix."""
    tagged = shout_in + "_validated"
    return {"shout_in_validated": tagged}


@op(inputs={"shout_in": SHOUTIN}, outputs={"shout_out": SHOUTOUT})
def echo_shout(shout_in):
    # Pass the (already validated) shout through unchanged as `shout_out`.
    return {"shout_out": shout_in}


class TestDefintion(AsyncTestCase):
async def setUp(self):
self.dataflow = DataFlow(
Expand Down Expand Up @@ -80,3 +98,34 @@ async def test_validation_error(self):
]
}
pass

async def test_vaildation_by_op(self):
    """Run a dataflow whose input definition names a validator op.

    ``SHOUTIN.validate`` is the instance name of ``validate_shouts``,
    so the seeded input must pass through that validator before
    ``echo_shout`` sees it; the final result therefore carries the
    ``_validated`` suffix.
    """
    dataflow_under_test = DataFlow(
        operations={
            "validate_shout_instance": validate_shouts.op,
            "echo_shout": echo_shout.op,
            "get_single": GetSingle.imp.op,
        },
        seed=[
            Input(
                value=[echo_shout.op.outputs["shout_out"].name],
                definition=GetSingle.op.inputs["spec"],
            )
        ],
        implementations={
            validate_shouts.op.name: validate_shouts.imp,
            echo_shout.op.name: echo_shout.imp,
        },
    )
    seed_value = "validation_status:"
    run_inputs = {
        "TestShoutOut": [Input(value=seed_value, definition=SHOUTIN)]
    }
    async with MemoryOrchestrator.withconfig({}) as orchestrator:
        async with orchestrator(dataflow_under_test) as octx:
            async for _ctx_str, results in octx.run(run_inputs):
                self.assertIn("shout_out", results)
                self.assertEqual(
                    results["shout_out"], seed_value + "_validated"
                )

0 comments on commit f87b180

Please sign in to comment.