From 37fd191ff8009793bd7bdaa10b020d9213d23468 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Mon, 29 Apr 2024 11:42:09 -0700 Subject: [PATCH] [minor] Factories instead of meta nodes, and use them for transformer nodes (#293) * Replace evaluated code to-/from-list nodes with actual classes Leveraging custom constructors * Develop meta abstraction and make a new home for transforming meta nodes * Accept node args too * Introduce a decorator for building the class-IO * Change paradigm to whether or not the node uses __reduced__ and a constructor Instead of "Meta" nodes * Allow direct use of Constructed children * Move and update constructed stuff * Add new singleton behaviour so factory-produced classes can pass is-tests * Apply constructed and class registration updates to the transformers * Remove (now unused) meta module * PEP8 newline * Remove unnecessary __getstate__ The object isn't holding instance level state and older versions of python bork here. * Add constructed __*state__ compatibility for older versions * :bug: add missing `return` * Format black * Introduce a new factory pattern At the cost of requiring factory functions to forgo kwargs, we get object matching for factories and classes, and picklability for instances. Still need to check some edge cases around the use of stuff with a non-trivial qualname. * Test factories as methods * Test local scoping * Add docstrings and hide a function * Simplify super * Give demo class in tests have a more generic __init__ * Test and document multiple inheritance * Stop storing __init__ args By leveraging `__new__` and `__getnewargs_ex__` * Add clearing methods In case you want to change some constructor behaviour and clear the access cache. Not sure what should be cached at the end of the day, but for now just give users a shortcut for clearing it. * Use snippets.factory.classfactory and typing.ClassVar for transformers * Revert singleton * Format black * Remove constructed It's superceded by the snippets.factory stuff * Gently rename things for when the factory comes from a decorator * Allow factory made classes to also come from decorators * Format black --------- Co-authored-by: pyiron-runner --- pyiron_workflow/create.py | 12 +- pyiron_workflow/io_preview.py | 23 +- pyiron_workflow/loops.py | 6 +- pyiron_workflow/meta.py | 46 --- pyiron_workflow/snippets/factory.py | 412 ++++++++++++++++++++++ pyiron_workflow/transform.py | 144 ++++++++ tests/integration/test_transform.py | 45 +++ tests/unit/snippets/test_factory.py | 482 ++++++++++++++++++++++++++ tests/unit/snippets/test_singleton.py | 1 - 9 files changed, 1109 insertions(+), 62 deletions(-) delete mode 100644 pyiron_workflow/meta.py create mode 100644 pyiron_workflow/snippets/factory.py create mode 100644 pyiron_workflow/transform.py create mode 100644 tests/integration/test_transform.py create mode 100644 tests/unit/snippets/test_factory.py diff --git a/pyiron_workflow/create.py b/pyiron_workflow/create.py index 72a497fe..60604b07 100644 --- a/pyiron_workflow/create.py +++ b/pyiron_workflow/create.py @@ -107,19 +107,15 @@ def Workflow(self): @property def meta(self): if self._meta is None: - from pyiron_workflow.meta import ( - input_to_list, - list_to_output, - ) - from pyiron_workflow.loops import while_loop - from pyiron_workflow.loops import for_loop + from pyiron_workflow.transform import inputs_to_list, list_to_outputs + from pyiron_workflow.loops import for_loop, while_loop from pyiron_workflow.snippets.dotdict import DotDict self._meta = DotDict( { for_loop.__name__: for_loop, - input_to_list.__name__: input_to_list, - list_to_output.__name__: list_to_output, + inputs_to_list.__name__: inputs_to_list, + list_to_outputs.__name__: list_to_outputs, while_loop.__name__: while_loop, } ) diff --git a/pyiron_workflow/io_preview.py b/pyiron_workflow/io_preview.py index ef4a495a..bc8a85e8 100644 --- a/pyiron_workflow/io_preview.py +++ b/pyiron_workflow/io_preview.py @@ -15,7 +15,7 @@ import inspect import warnings from abc import ABC, abstractmethod -from functools import lru_cache +from functools import lru_cache, wraps from textwrap import dedent from types import FunctionType from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING @@ -79,6 +79,22 @@ def preview_io(cls) -> DotDict[str:dict]: ) +def builds_class_io(subclass_factory: callable[..., type[HasIOPreview]]): + """ + A decorator for factories producing subclasses of `HasIOPreview` to invoke + :meth:`preview_io` after the class is created, thus ensuring the IO has been + constructed at the class level. + """ + + @wraps(subclass_factory) + def wrapped(*args, **kwargs): + node_class = subclass_factory(*args, **kwargs) + node_class.preview_io() + return node_class + + return wrapped + + class ScrapesIO(HasIOPreview, ABC): """ A mixin class for scraping IO channel information from a specific class method's @@ -391,6 +407,7 @@ def as_decorated_node_decorator( ): output_labels = None if len(output_labels) == 0 else output_labels + @builds_class_io def as_decorated_node(io_defining_function: callable): if not callable(io_defining_function): raise AttributeError( @@ -398,7 +415,7 @@ def as_decorated_node(io_defining_function: callable): f"but got {io_defining_function} instead of a callable." ) - decorated_node_class = type( + return type( io_defining_function.__name__, (parent_class,), # Define parentage { @@ -409,8 +426,6 @@ def as_decorated_node(io_defining_function: callable): **parent_class_attr_overrides, }, ) - decorated_node_class.preview_io() # Construct everything - return decorated_node_class return as_decorated_node diff --git a/pyiron_workflow/loops.py b/pyiron_workflow/loops.py index 8f8adb80..e7f9028d 100644 --- a/pyiron_workflow/loops.py +++ b/pyiron_workflow/loops.py @@ -8,7 +8,7 @@ from pyiron_workflow import Workflow from pyiron_workflow.function import Function from pyiron_workflow.macro import Macro -from pyiron_workflow.meta import input_to_list, list_to_output +from pyiron_workflow.transform import inputs_to_list, list_to_outputs from pyiron_workflow.node import Node @@ -92,7 +92,7 @@ def for_loop( f"{l}={l.upper()}[n]" if l in iterate_on else f"{l}={l}" for l in input_preview.keys() ).rstrip(" ") - input_label = 'f"inp{n}"' + input_label = 'f"item_{n}"' returns = ", ".join( f'self.children["{label.upper()}"]' for label in output_preview.keys() ) @@ -106,7 +106,7 @@ def {node_name}(self, {macro_args}): from {loop_body_class.__module__} import {loop_body_class.__name__} for label in [{output_labels}]: - input_to_list({length})(label=label, parent=self) + inputs_to_list({length}, label=label, parent=self) for n in range({length}): body_node = {loop_body_class.__name__}( diff --git a/pyiron_workflow/meta.py b/pyiron_workflow/meta.py deleted file mode 100644 index f77d6a26..00000000 --- a/pyiron_workflow/meta.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Meta nodes are callables that create a node class instead of a node instance. -""" - -from __future__ import annotations - -from pyiron_workflow.function import Function, as_function_node - - -def list_to_output(length: int, **node_class_kwargs) -> type[Function]: - """ - A meta-node that returns a node class with :param:`length` input channels and - maps these to a single output channel with type `list`. - """ - - def _list_to_many(length: int): - template = f""" -def __list_to_many(input_list: list): - {"; ".join([f"out{i} = input_list[{i}]" for i in range(length)])} - return {", ".join([f"out{i}" for i in range(length)])} - """ - exec(template) - return locals()["__list_to_many"] - - return as_function_node(*(f"output{n}" for n in range(length)))( - _list_to_many(length=length), **node_class_kwargs - ) - - -def input_to_list(length: int, **node_class_kwargs) -> type[Function]: - """ - A meta-node that returns a node class with :param:`length` output channels and - maps an input list to these. - """ - - def _many_to_list(length: int): - template = f""" -def __many_to_list({", ".join([f"inp{i}=None" for i in range(length)])}): - return [{", ".join([f"inp{i}" for i in range(length)])}] - """ - exec(template) - return locals()["__many_to_list"] - - return as_function_node("output_list")( - _many_to_list(length=length), **node_class_kwargs - ) diff --git a/pyiron_workflow/snippets/factory.py b/pyiron_workflow/snippets/factory.py new file mode 100644 index 00000000..69c8797d --- /dev/null +++ b/pyiron_workflow/snippets/factory.py @@ -0,0 +1,412 @@ +""" +Tools for making dynamically generated classes unique, and their instances pickleable. + +Provides two main user-facing tools: :func:`classfactory`, which should be used +_exclusively_ as a decorator (this restriction pertains to namespace requirements for +re-importing), and `ClassFactory`, which can be used to instantiate a new factory from +some existing factory function. + +In both cases, the decorated function/input argument should be a pickleable function +taking only positional arguments, and returning a tuple suitable for use in dynamic +class creation via :func:`builtins.type` -- i.e. taking a class name, a tuple of base +classes, a dictionary of class attributes, and a dictionary of values to be expanded +into kwargs for `__subclass_init__`. + +The resulting factory produces classes that are (a) pickleable, and (b) the same object +as any previously built class with the same name. (Note: avoiding class degeneracy with +respect to class name is the responsibility of the person writing the factory function.) + +These classes are then themselves pickleable, and produce instances which are in turn +pickleable (so long as any data they've been fed as inputs or attributes is pickleable, +i.e. here the only pickle-barrier we resolve is that of having come from a dynamically +generated class). + +Since users need to build their own class factories returning classes with sensible +names, we also provide a helper function :func:`sanitize_callable_name`, which makes +sure a string is compliant with use as a class name. This is run internally on user- +provided names, and failure for the user name and sanitized name to match will give a +clear error message. + +Constructed classes can, in turn be used as bases in further class factories. +""" + +from __future__ import annotations + +from abc import ABC, ABCMeta +from functools import wraps +from importlib import import_module +from inspect import signature, Parameter +from re import sub +from typing import ClassVar + + +class _SingleInstance(ABCMeta): + """Simple singleton pattern.""" + + _instance = None + + def __call__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(_SingleInstance, cls).__call__(*args, **kwargs) + return cls._instance + + +class _FactoryTown(metaclass=_SingleInstance): + """ + Makes sure two factories created around the same factory function are the same + factory object. + """ + + factories = {} + + @classmethod + def clear(cls): + """ + Remove factories. + + Can be useful if you're + """ + cls.factories = {} + + @staticmethod + def _factory_address(factory_function: callable) -> str: + return f"{factory_function.__module__}.{factory_function.__qualname__}" + + def get_factory(self, factory_function: callable[..., type]) -> _ClassFactory: + + self._verify_function_only_takes_positional_args(factory_function) + + address = self._factory_address(factory_function) + + try: + return self.factories[address] + except KeyError: + factory = self._build_factory(factory_function) + self.factories[address] = factory + return factory + + @staticmethod + def _build_factory(factory_function): + """ + Subclass :class:`_ClassFactory` and make an instance. + """ + new_factory_class = type( + sanitize_callable_name( + f"{factory_function.__module__}{factory_function.__qualname__}" + f"{factory_function.__name__.title()}" + f"{_ClassFactory.__name__}" + ).replace("_", ""), + (_ClassFactory,), + {}, + factory_function=factory_function, + ) + return wraps(factory_function)(new_factory_class()) + + @staticmethod + def _verify_function_only_takes_positional_args(factory_function: callable): + parameters = signature(factory_function).parameters.values() + if any( + p.kind not in [Parameter.POSITIONAL_ONLY, Parameter.VAR_POSITIONAL] + for p in parameters + ): + raise InvalidFactorySignature( + f"{_ClassFactory.__name__} can only be subclassed using factory " + f"functions that take exclusively positional arguments, but " + f"{factory_function.__name__} has the parameters {parameters}" + ) + + +_FACTORY_TOWN = _FactoryTown() + + +class InvalidFactorySignature(ValueError): + """When the factory function's arguments are not purely positional""" + + pass + + +class InvalidClassNameError(ValueError): + """When a string isn't a good class name""" + + pass + + +class _ClassFactory(metaclass=_SingleInstance): + """ + For making dynamically created classes the same class. + """ + + _decorated_as_classfactory: ClassVar[bool] = False + + def __init_subclass__(cls, /, factory_function, **kwargs): + super().__init_subclass__(**kwargs) + cls.factory_function = staticmethod(factory_function) + cls.class_registry = {} + + def __call__(self, *args) -> type[_FactoryMade]: + name, bases, class_dict, sc_init_kwargs = self.factory_function(*args) + self._verify_name_is_legal(name) + try: + return self.class_registry[name] + except KeyError: + factory_made = self._build_class( + name, + bases, + class_dict, + sc_init_kwargs, + args, + ) + self.class_registry[name] = factory_made + return factory_made + + @classmethod + def clear(cls): + """ + Remove constructed classes. + + Can be useful if you've updated the constructor and want to remove old + instances. + """ + cls.class_registry = {} + + def _build_class( + self, name, bases, class_dict, sc_init_kwargs, class_factory_args + ) -> type[_FactoryMade]: + + class_dict["__module__"] = self.factory_function.__module__ + sc_init_kwargs["class_factory"] = self + sc_init_kwargs["class_factory_args"] = class_factory_args + + if not any(_FactoryMade in base.mro() for base in bases): + bases = (_FactoryMade, *bases) + + return type(name, bases, class_dict, **sc_init_kwargs) + + @staticmethod + def _verify_name_is_legal(name): + sanitized_name = sanitize_callable_name(name) + if name != sanitized_name: + raise InvalidClassNameError( + f"The class name {name} failed to match with its sanitized version" + f"({sanitized_name}), please supply a valid class name." + ) + + def __reduce__(self): + if self._decorated_as_classfactory: + # When we create a factory by decorating the factory function, this object + # conflicts with its own factory_function attribute in the namespace, so we + # rely on directly re-importing the factory + return ( + _import_object, + (self.factory_function.__module__, self.factory_function.__qualname__), + ) + else: + return (_FACTORY_TOWN.get_factory, (self.factory_function,)) + + +def _import_object(module_name, qualname): + module = import_module(module_name) + obj = module + for name in qualname.split("."): + obj = getattr(obj, name) + return obj + + +class _FactoryMade(ABC): + """ + A mix-in to make class-factory-produced classes pickleable. + + If the factory is used as a decorator for another function, it will conflict with + this function (i.e. the owned function will be the true function, and will mismatch + with imports from that location, which will return the post-decorator factory made + class). This can be resolved by setting the + :attr:`_class_returns_from_decorated_function` attribute to be the decorated + function in the decorator definition. + """ + + _class_returns_from_decorated_function: ClassVar[callable | None] = None + + def __init_subclass__(cls, /, class_factory, class_factory_args, **kwargs): + super().__init_subclass__(**kwargs) + cls._class_factory = class_factory + cls._class_factory_args = class_factory_args + cls._factory_town = _FACTORY_TOWN + + def __reduce__(self): + if self._class_returns_from_decorated_function is not None: + # When we create a class by decorating some other function, this class + # conflicts with its own factory_function attribute in the namespace, so we + # rely on directly re-importing the factory + return ( + _instantiate_from_decorated, + ( + self._class_returns_from_decorated_function.__module__, + self._class_returns_from_decorated_function.__qualname__, + self.__getnewargs_ex__(), + ), + self.__getstate__(), + ) + else: + return ( + _instantiate_from_factory, + ( + self._class_factory, + self._class_factory_args, + self.__getnewargs_ex__(), + ), + self.__getstate__(), + ) + + def __getnewargs_ex__(self): + # Child classes can override this as needed + return (), {} + + def __getstate__(self): + # Python <3.11 compatibility + try: + return super().__getstate__() + except AttributeError: + return dict(self.__dict__) + + def __setstate__(self, state): + # Python <3.11 compatibility + try: + super().__setstate__(state) + except AttributeError: + self.__dict__.update(**state) + + +def _instantiate_from_factory(factory, factory_args, newargs_ex): + """ + Recover the dynamic class, then invoke its `__new__` to avoid instantiation (and + the possibility of positional args in `__init__`). + """ + cls = factory(*factory_args) + return cls.__new__(cls, *newargs_ex[0], **newargs_ex[1]) + + +def _instantiate_from_decorated(module, qualname, newargs_ex): + """ + In case the class comes from a decorated function, we need to import it directly. + """ + cls = _import_object(module, qualname) + return cls.__new__(cls, *newargs_ex[0], **newargs_ex[1]) + + +def classfactory( + factory_function: callable[..., tuple[str, tuple[type, ...], dict, dict]] +) -> _ClassFactory: + """ + A decorator for building dynamic class factories whose classes are unique and whose + terminal instances can be pickled. + + Under the hood, classes created by factories get dependence on + :class:`_FactoryMade` mixed in. This class leverages :meth:`__reduce__` and + :meth:`__init_subclass__` and uses up the class namespace :attr:`_class_factory` + and :attr:`_class_factory_args` to hold data (using up corresponding public + variable names in the :meth:`__init_subclass__` kwargs), so any interference with + these fields may cause unexpected side effects. For un-pickling, the dynamic class + gets recreated then its :meth:`__new__` is called using `__newargs_ex__`; a default + implementation returning no arguments is provided on :class:`_FactoryMade` but can + be overridden. + + Args: + factory_function (callable[..., tuple[str, tuple[type, ...], dict, dict]]): + A function returning arguments that would be passed to `builtins.type` to + dynamically generate a class. The function must accept exclusively + positional arguments + + Returns: + (type[_ClassFactory]): A new callable that returns unique classes whose + instances can be pickled. + + Notes: + If the :param:`factory_function` itself, or any data stored on instances of + its resulting class(es) cannot be pickled, then the instances will not be able + to be pickled. Here we only remove the trouble associated with pickling + dynamically created classes. + + If the `__init_subclass__` kwargs are exploited, remember that these are + subject to all the same "gotchas" as their regular non-factory use; namely, all + child classes must specify _all_ parent class kwargs in order to avoid them + getting overwritten by the parent class defaults! + + Dynamically generated classes can, in turn, be used as base classes for further + `@classfactory` decorated factory functions. + + Warnings: + Use _exclusively_ as a decorator. For an inline constructor for an existing + callable, use :class:`ClassFactory` instead. + + Examples: + >>> import pickle + >>> + >>> from pyiron_workflow.snippets.factory import classfactory + >>> + >>> class HasN(ABC): + ... '''Some class I want to make dynamically subclass.''' + ... def __init_subclass__(cls, /, n=0, s="foo", **kwargs): + ... super(HasN, cls).__init_subclass__(**kwargs) + ... cls.n = n + ... cls.s = s + ... + ... def __init__(self, x, y=0): + ... self.x = x + ... self.y = y + >>> + >>> @classfactory + ... def has_n_factory(n, s="wrapped_function", /): + ... return ( + ... f"{HasN.__name__}{n}{s}", # New class name + ... (HasN,), # Base class(es) + ... {}, # Class attributes dictionary + ... {"n": n, "s": s} + ... # dict of `builtins.type` kwargs (passed to `__init_subclass__`) + ... ) + >>> + >>> Has2 = has_n_factory(2, "my_dynamic_class") + >>> HasToo = has_n_factory(2, "my_dynamic_class") + >>> HasToo is Has2 + True + + >>> foo = Has2(42, y=-1) + >>> print(foo.n, foo.s, foo.x, foo.y) + 2 my_dynamic_class 42 -1 + + >>> reloaded = pickle.loads(pickle.dumps(foo)) # doctest: +SKIP + >>> print(reloaded.n, reloaded.s, reloaded.x, reloaded.y) # doctest: +SKIP + 2 my_dynamic_class 42 -1 # doctest: +SKIP + + """ + factory = _FACTORY_TOWN.get_factory(factory_function) + factory._decorated_as_classfactory = True + return factory + + +class ClassFactory: + """ + A constructor for new class factories. + + Use on existing class factory callables, _not_ as a decorator. + + Cf. the :func:`classfactory` decorator for more info. + """ + + def __new__(cls, factory_function): + return _FACTORY_TOWN.get_factory(factory_function) + + +def sanitize_callable_name(name: str): + """ + A helper class for sanitizing a string so it's appropriate as a class/function name. + """ + # Replace non-alphanumeric characters except underscores + sanitized_name = sub(r"\W+", "_", name) + # Ensure the name starts with a letter or underscore + if ( + len(sanitized_name) > 0 + and not sanitized_name[0].isalpha() + and sanitized_name[0] != "_" + ): + sanitized_name = "_" + sanitized_name + return sanitized_name diff --git a/pyiron_workflow/transform.py b/pyiron_workflow/transform.py new file mode 100644 index 00000000..b03e33ec --- /dev/null +++ b/pyiron_workflow/transform.py @@ -0,0 +1,144 @@ +""" +Transformer nodes convert many inputs into a single output, or vice-versa. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, ClassVar + +from pyiron_workflow.channels import NOT_DATA +from pyiron_workflow.io_preview import StaticNode, builds_class_io +from pyiron_workflow.snippets.factory import classfactory + + +class Transformer(StaticNode, ABC): + """ + Transformers are a special :class:`Constructed` case of :class:`StaticNode` nodes + that turn many inputs into a single output or vice-versa. + """ + + def to_dict(self): + pass # Vestigial abstract method + + +class FromManyInputs(Transformer, ABC): + _output_name: ClassVar[str] # Mandatory attribute for non-abstract subclasses + _output_type_hint: ClassVar[Any] = None + + @staticmethod + @abstractmethod + def transform_from_input(inputs_as_dict: dict): + pass + + # _build_inputs_preview required from parent class + # This must be commensurate with the internal expectations of transform_from_input + + @property + def on_run(self) -> callable[..., Any | tuple]: + return self.transform_from_input + + @property + def run_args(self) -> dict: + return {"inputs_as_dict": self.inputs.to_value_dict()} + + @classmethod + def _build_outputs_preview(cls) -> dict[str, Any]: + return {cls._output_name: cls._output_type_hint} + + def process_run_result(self, run_output: Any | tuple) -> Any | tuple: + self.outputs[self._output_name].value = run_output + return run_output + + +class ToManyOutputs(Transformer, ABC): + _input_name: ClassVar[str] # Mandatory attribute for non-abstract subclasses + _input_type_hint: ClassVar[Any] = None + _input_default: ClassVar[Any | NOT_DATA] = NOT_DATA + + @staticmethod + @abstractmethod + def transform_to_output(input_data) -> dict[str, Any]: + pass + + # _build_outputs_preview still required from parent class + # Must be commensurate with the dictionary returned by transform_to_output + + @property + def on_run(self) -> callable[..., Any | tuple]: + return self.transform_to_output + + @property + def run_args(self) -> dict: + return { + "input_data": self.inputs[self._input_name].value, + } + + @classmethod + def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: + return {cls._input_name: (cls._input_type_hint, cls._input_default)} + + def process_run_result(self, run_output: dict[str, Any]) -> dict[str, Any]: + for k, v in run_output.items(): + self.outputs[k].value = v + return run_output + + +class ListTransformer(Transformer, ABC): + _length: ClassVar[int] # Mandatory attribute for non-abstract subclasses + + +class InputsToList(ListTransformer, FromManyInputs, ABC): + _output_name: ClassVar[str] = "list" + _output_type_hint: ClassVar[Any] = list + + @staticmethod + def transform_from_input(inputs_as_dict: dict): + return list(inputs_as_dict.values()) + + @classmethod + def _build_inputs_preview(cls) -> dict[str, tuple[Any, Any]]: + return {f"item_{i}": (None, NOT_DATA) for i in range(cls._length)} + + +class ListToOutputs(ListTransformer, ToManyOutputs, ABC): + _input_name: ClassVar[str] = "list" + _input_type_hint: ClassVar[Any] = list + + @staticmethod + def transform_to_output(input_data: list): + return {f"item_{i}": v for i, v in enumerate(input_data)} + + @classmethod + def _build_outputs_preview(cls) -> dict[str, Any]: + return {f"item_{i}": None for i in range(cls._length)} + + +@builds_class_io +@classfactory +def inputs_to_list_factory(n: int, /) -> type[InputsToList]: + return ( + f"{InputsToList.__name__}{n}", + (InputsToList,), + {"_length": n}, + {}, + ) + + +def inputs_to_list(n: int, *node_args, **node_kwargs): + return inputs_to_list_factory(n)(*node_args, **node_kwargs) + + +@builds_class_io +@classfactory +def list_to_outputs_factory(n: int, /) -> type[ListToOutputs]: + return ( + f"{ListToOutputs.__name__}{n}", + (ListToOutputs,), + {"_length": n}, + {}, + ) + + +def list_to_outputs(n: int, /, *node_args, **node_kwargs) -> ListToOutputs: + return list_to_outputs_factory(n)(*node_args, **node_kwargs) diff --git a/tests/integration/test_transform.py b/tests/integration/test_transform.py new file mode 100644 index 00000000..deef6453 --- /dev/null +++ b/tests/integration/test_transform.py @@ -0,0 +1,45 @@ +import pickle +import unittest + +from pyiron_workflow.transform import ( + inputs_to_list, + inputs_to_list_factory, + list_to_outputs, + list_to_outputs_factory +) + + +class TestTransform(unittest.TestCase): + def test_list(self): + n = 3 + inp = inputs_to_list(n, *list(range(n)), label="inp") + out = list_to_outputs(n, inp, label="out") + out() + self.assertListEqual( + list(range(3)), + out.outputs.to_list(), + msg="Expected behaviour here is an autoencoder" + ) + + inp_class = inputs_to_list_factory(n) + out_class = list_to_outputs_factory(n) + + self.assertIs( + inp_class, + inp.__class__, + msg="Regardless of origin, we expect to be constructing the exact same " + "class" + ) + self.assertIs(out_class, out.__class__) + + reloaded = pickle.loads(pickle.dumps(out)) + self.assertEqual( + out.label, + reloaded.label, + msg="Transformers should be pickleable" + ) + self.assertDictEqual( + out.outputs.to_value_dict(), + reloaded.outputs.to_value_dict(), + msg="Transformers should be pickleable" + ) diff --git a/tests/unit/snippets/test_factory.py b/tests/unit/snippets/test_factory.py new file mode 100644 index 00000000..5e1bf5bb --- /dev/null +++ b/tests/unit/snippets/test_factory.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +from abc import ABC +import pickle +from typing import ClassVar +import unittest + +from pyiron_workflow.snippets.factory import ( + _ClassFactory, + _FactoryMade, + ClassFactory, + classfactory, + InvalidClassNameError, + InvalidFactorySignature, + sanitize_callable_name +) + + +class HasN(ABC): + def __init_subclass__(cls, /, n=0, s="foo", **kwargs): + super().__init_subclass__(**kwargs) + cls.n = n + cls.s = s + + def __init__(self, x, *args, y=0, **kwargs): + super().__init__(*args, **kwargs) + self.x = x + self.y = y + + +@classfactory +def has_n_factory(n, s="wrapped_function", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +def undecorated_function(n, s="undecorated_function", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +def takes_kwargs(n, /, s="undecorated_function"): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +class FactoryOwner: + @staticmethod + @classfactory + def has_n_factory(n, s="decorated_method", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + +Has2 = has_n_factory(2, "factory_made") # For testing repeated inheritance + + +class HasM(ABC): + def __init_subclass__(cls, /, m=0, **kwargs): + super(HasM, cls).__init_subclass__(**kwargs) + cls.m = m + + def __init__(self, z, *args, **kwargs): + super().__init__(*args, **kwargs) + self.z = z + + +@classfactory +def has_n2_m_factory(m, /): + return ( + f"HasN2M{m}", + (Has2, HasM), + {}, + {"m": m, "n": Has2.n, "s": Has2.s} + ) + + +@classfactory +def has_m_n2_factory(m, /): + return ( + f"HasM{m}N2", + (HasM, Has2,), + {}, + {"m": m} + ) + + +class AddsNandX(ABC): + fnc: ClassVar[callable] + n: ClassVar[int] + + def __init__(self, x): + self.x = x + + def add_to_function(self, *args, **kwargs): + return self.fnc(*args, **kwargs) + self.n + self.x + + +@classfactory +def adder_factory(fnc, n, /): + return ( + f"{AddsNandX.__name__}{fnc.__name__}", + (AddsNandX,), + { + "fnc": staticmethod(fnc), + "n": n, + "_class_returns_from_decorated_function": fnc + }, + {}, + ) + + +def add_to_this_decorator(n): + def wrapped(fnc): + factory_made = adder_factory(fnc, n) + factory_made._class_returns_from_decorated_function = fnc + return factory_made + return wrapped + + +@add_to_this_decorator(5) +def adds_5_plus_x(y: int): + return y + + +class TestClassfactory(unittest.TestCase): + + def test_factory_initialization(self): + self.assertTrue( + issubclass(has_n_factory.__class__, _ClassFactory), + msg="Creation by decorator should yield a subclass" + ) + self.assertTrue( + issubclass(ClassFactory(undecorated_function).__class__, _ClassFactory), + msg="Creation by public instantiator should yield a subclass" + ) + + factory = has_n_factory(2, "foo") + self.assertTrue( + issubclass(factory, HasN), + msg=f"Resulting class should inherit from the base" + ) + self.assertEqual(2, factory.n, msg="Factory args should get interpreted") + self.assertEqual("foo", factory.s, msg="Factory kwargs should get interpreted") + + def test_factory_uniqueness(self): + f1 = classfactory(undecorated_function) + f2 = classfactory(undecorated_function) + + self.assertIs( + f1, + f2, + msg="Repeatedly packaging the same function should give the exact same " + "factory" + ) + self.assertIsNot( + f1, + has_n_factory, + msg="Factory degeneracy is based on the actual wrapped function, we don't " + "do any parsing for identical behaviour inside those functions." + ) + + def test_factory_pickle(self): + with self.subTest("By decoration"): + reloaded = pickle.loads(pickle.dumps(has_n_factory)) + self.assertIs(has_n_factory, reloaded) + + with self.subTest("From instantiation"): + my_factory = ClassFactory(undecorated_function) + reloaded = pickle.loads(pickle.dumps(my_factory)) + self.assertIs(my_factory, reloaded) + + with self.subTest("From qualname by decoration"): + my_factory = FactoryOwner().has_n_factory + reloaded = pickle.loads(pickle.dumps(my_factory)) + self.assertIs(my_factory, reloaded) + + def test_class_creation(self): + n2 = has_n_factory(2, "something") + self.assertEqual( + 2, + n2.n, + msg="Factory args should be getting parsed" + ) + self.assertEqual( + "something", + n2.s, + msg="Factory kwargs should be getting parsed" + ) + self.assertTrue( + issubclass(n2, HasN), + msg="" + ) + self.assertTrue( + issubclass(n2, HasN), + msg="Resulting classes should inherit from the requested base(s)" + ) + + with self.assertRaises( + InvalidClassNameError, + msg="Invalid class names should raise an error" + ): + has_n_factory( + 2, + "our factory function uses this as part of the class name, but spaces" + "are not allowed!" + ) + + def test_class_uniqueness(self): + n2 = has_n_factory(2) + + self.assertIs( + n2, + has_n_factory(2), + msg="Repeatedly creating the same class should give the exact same class" + ) + self.assertIsNot( + n2, + has_n_factory(2, "something_else"), + msg="Sanity check" + ) + + def test_bad_factory_function(self): + with self.assertRaises( + InvalidFactorySignature, + msg="For compliance with __reduce__, we can only use factory functions " + "that strictly take positional arguments" + ): + ClassFactory(takes_kwargs) + + def test_instance_creation(self): + foo = has_n_factory(2, "used")(42, y=43) + self.assertEqual( + 2, foo.n, msg="Class attributes should be inherited" + ) + self.assertEqual( + "used", foo.s, msg="Class attributes should be inherited" + ) + self.assertEqual( + 42, foo.x, msg="Initialized args should be captured" + ) + self.assertEqual( + 43, foo.y, msg="Initialized kwargs should be captured" + ) + self.assertIsInstance( + foo, + HasN, + msg="Instances should inherit from the requested base(s)" + ) + self.assertIsInstance( + foo, + _FactoryMade, + msg="Instances should get :class:`_FactoryMade` mixed in." + ) + + def test_instance_pickle(self): + foo = has_n_factory(2, "used")(42, y=43) + reloaded = pickle.loads(pickle.dumps(foo)) + self.assertEqual( + foo.n, reloaded.n, msg="Class attributes should be reloaded" + ) + self.assertEqual( + foo.s, reloaded.s, msg="Class attributes should be reloaded" + ) + self.assertEqual( + foo.x, reloaded.x, msg="Initialized args should be reloaded" + ) + self.assertEqual( + foo.y, reloaded.y, msg="Initialized kwargs should be reloaded" + ) + self.assertIsInstance( + reloaded, + HasN, + msg="Instances should inherit from the requested base(s)" + ) + self.assertIsInstance( + reloaded, + _FactoryMade, + msg="Instances should get :class:`_FactoryMade` mixed in." + ) + + def test_decorated_method(self): + msg = "It should be possible to have class factories as methods on a class" + foo = FactoryOwner().has_n_factory(2)(42, y=43) + reloaded = pickle.loads(pickle.dumps(foo)) + self.assertEqual(foo.n, reloaded.n, msg=msg) + self.assertEqual(foo.s, reloaded.s, msg=msg) + self.assertEqual(foo.x, reloaded.x, msg=msg) + self.assertEqual(foo.y, reloaded.y, msg=msg) + + def test_factory_inside_a_function(self): + @classfactory + def internal_factory(n, s="unimportable_scope", /): + return ( + f"{HasN.__name__}{n}{s}", + (HasN,), + {}, + {"n": n, "s": s} + ) + + foo = internal_factory(2)(1, y=0) + self.assertEqual(2, foo.n, msg="Nothing should stop the factory from working") + self.assertEqual( + "unimportable_scope", + foo.s, + msg="Nothing should stop the factory from working" + ) + self.assertEqual(1, foo.x, msg="Nothing should stop the factory from working") + self.assertEqual(0, foo.y, msg="Nothing should stop the factory from working") + with self.assertRaises( + AttributeError, + msg="`internal_factory` is defined only locally inside the scope of " + "another function, so we don't expect it to be pickleable whether it's " + "a class factory or not!" + ): + pickle.loads(pickle.dumps(foo)) + + def test_repeated_inheritance(self): + n2m3 = has_n2_m_factory(3)(5, 6) + m3n2 = has_m_n2_factory(3)(5, 6) + + self.assertListEqual( + [3, 2, "factory_made"], + [n2m3.m, n2m3.n, n2m3.s], + msg="Sanity check on class property inheritance" + ) + self.assertListEqual( + [3, 0, "foo"], # n and s defaults from HasN! + [m3n2.m, m3n2.n, m3n2.s], + msg="When exploiting __init_subclass__, each subclass must take care to " + "specify _all_ parent class __init_subclass__ kwargs, or they will " + "revert to the default behaviour. This is totally normal python " + "behaviour, and here we just verify that we're vulnerable to the same " + "'gotcha' as the rest of the language." + ) + self.assertListEqual( + [5, 6], + [n2m3.x, n2m3.z], + msg="Sanity check on instance inheritance" + ) + self.assertListEqual( + [m3n2.z, m3n2.x], + [n2m3.x, n2m3.z], + msg="Inheritance order should impact arg order, also completely as usual " + "for python classes" + ) + reloaded_nm = pickle.loads(pickle.dumps(n2m3)) + self.assertListEqual( + [n2m3.m, n2m3.n, n2m3.s, n2m3.z, n2m3.x, n2m3.y], + [ + reloaded_nm.m, + reloaded_nm.n, + reloaded_nm.s, + reloaded_nm.z, + reloaded_nm.x, + reloaded_nm.y + ], + msg="Pickling behaviour should not care that one of the parents was itself " + "a factory made class." + ) + + reloaded_mn = pickle.loads(pickle.dumps(m3n2)) + self.assertListEqual( + [m3n2.m, m3n2.n, m3n2.s, m3n2.z, m3n2.x, m3n2.y], + [ + reloaded_mn.m, + reloaded_mn.n, + reloaded_mn.s, + reloaded_mn.z, + reloaded_mn.x, + reloaded_nm.y + ], + msg="Pickling behaviour should not care about the order of bases." + ) + + def test_clearing_town(self): + + self.assertGreater(len(Has2._factory_town.factories), 0, msg="Sanity check") + + Has2._factory_town.clear() + self.assertEqual( + len(Has2._factory_town.factories), + 0, + msg="Town should get cleared" + ) + + ClassFactory(undecorated_function) + self.assertEqual( + len(Has2._factory_town.factories), + 1, + msg="Has2 exists in memory and the factory town has forgotten about it, " + "but it still knows about the factory town and can see the newly " + "created one." + ) + + def test_clearing_class_register(self): + self.assertGreater( + len(has_n_factory.class_registry), + 0, + msg="Sanity. We expect to have created at least one class up in the header." + ) + has_n_factory.clear() + self.assertEqual( + len(has_n_factory.class_registry), + 0, + msg="Clear should remove all instances" + ) + n_new = 3 + for i in range(n_new): + has_n_factory(i) + self.assertEqual( + len(has_n_factory.class_registry), + n_new, + msg="Should see the new constructed classes" + ) + + def test_other_decorators(self): + """ + In case the factory-produced class itself comes from a decorator, we need to + check that name conflicts between the class and decorated function are handled. + """ + a5 = adds_5_plus_x(2) + self.assertIsInstance(a5, AddsNandX) + self.assertIsInstance(a5, _FactoryMade) + self.assertEqual(5, a5.n) + self.assertEqual(2, a5.x) + self.assertEqual( + 1 + 5 + 2, # y + n=5 + x=2 + a5.add_to_function(1), + msg="Should execute the function as part of call" + ) + + reloaded = pickle.loads(pickle.dumps(a5)) + self.assertEqual(a5.n, reloaded.n) + self.assertIs(a5.fnc, reloaded.fnc) + self.assertEqual(a5.x, reloaded.x) + + +class TestSanitization(unittest.TestCase): + + def test_simple_string(self): + self.assertEqual(sanitize_callable_name("SimpleString"), "SimpleString") + + def test_string_with_spaces(self): + self.assertEqual( + sanitize_callable_name("String with spaces"), "String_with_spaces" + ) + + def test_string_with_special_characters(self): + self.assertEqual(sanitize_callable_name("a!@#$%b^&*()c"), "a_b_c") + + def test_string_with_numbers_at_start(self): + self.assertEqual(sanitize_callable_name("123Class"), "_123Class") + + def test_empty_string(self): + self.assertEqual(sanitize_callable_name(""), "") + + def test_string_with_only_special_characters(self): + self.assertEqual(sanitize_callable_name("!@#$%"), "_") + + def test_string_with_only_numbers(self): + self.assertEqual(sanitize_callable_name("123456"), "_123456") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/snippets/test_singleton.py b/tests/unit/snippets/test_singleton.py index d18c848a..7953445e 100644 --- a/tests/unit/snippets/test_singleton.py +++ b/tests/unit/snippets/test_singleton.py @@ -17,4 +17,3 @@ def __init__(self): if __name__ == '__main__': unittest.main() -