diff --git a/mephisto/abstractions/blueprint.py b/mephisto/abstractions/blueprint.py index 5098df1a7..a35d9f36a 100644 --- a/mephisto/abstractions/blueprint.py +++ b/mephisto/abstractions/blueprint.py @@ -534,8 +534,55 @@ def get_task_end(self) -> Optional[float]: class BlueprintMixin(ABC): """ Base class for compositional mixins for blueprints + + We expect mixins that subclass other mixins to handle subinitialization + work, such that only the highest class needs to be called. """ + @property + @abstractmethod + def ArgsMixin(self) -> Any: # Should be a dataclass, to extend BlueprintArgs + pass + + @property + @abstractmethod + def SharedStateMixin( + self, + ) -> Any: # Also should be a dataclass, to extend SharedTaskState + pass + + @staticmethod + def extract_unique_mixins(blueprint_class: ClassVar[Type["Blueprint"]]): + """Return the unique mixin classes that are used in the given blueprint class""" + mixin_subclasses = [ + clazz + for clazz in blueprint_class.mro() + if issubclass(clazz, BlueprintMixin) + ] + # Remove magic created with `mixin_args_and_state` + while blueprint_class.__name__ == "MixedInBlueprint": + blueprint_class = mixin_subclasses.pop(0) + removed_locals = [ + clazz + for clazz in mixin_subclasses + if "MixedInBlueprint" not in clazz.__name__ + ] + filtered_subclasses = set( + clazz + for clazz in removed_locals + if clazz != BlueprintMixin and clazz != blueprint_class + ) + # we also want to make sure that we don't double-count extensions of mixins, so remove classes that other classes are subclasses of + def is_subclassed(clazz): + return True in [ + issubclass(x, clazz) and x != clazz for x in filtered_subclasses + ] + + unique_subclasses = [ + clazz for clazz in filtered_subclasses if not is_subclassed(clazz) + ] + return unique_subclasses + @abstractmethod def init_mixin_config( self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" @@ -545,7 +592,7 @@ def init_mixin_config( @classmethod @abstractmethod - def assert_task_args( + def assert_mixin_args( cls, args: "DictConfig", shared_state: "SharedTaskState" ) -> None: """Method to validate the incoming args and throw if something won't work""" @@ -559,6 +606,37 @@ def get_mixin_qualifications( """Method to provide any required qualifications to make this mixin function""" raise NotImplementedError() + @classmethod + def mixin_args_and_state(mixin_cls: "BlueprintMixin", target_cls: "Blueprint"): + """ + Magic utility decorator that can be used to inject mixin configurations + (BlueprintArgs and SharedTaskState) without the user needing to define new + classes for these. Should only be used by tasks that aren't already specifying + new versions of these, which should just inherit otherwise. + + Usage: + @register_mephisto_abstraction() + @ABlueprintMixin.mixin_args_and_state + class MyBlueprint(ABlueprintMixin, Blueprint): + pass + """ + + @dataclass + class MixedInArgsClass(mixin_cls.ArgsMixin, target_cls.ArgsClass): + pass + + @dataclass + class MixedInSharedStateClass( + mixin_cls.SharedStateMixin, target_cls.SharedStateClass + ): + pass + + class MixedInBlueprint(target_cls): + ArgsClass = MixedInArgsClass + SharedStateClass = MixedInSharedStateClass + + return MixedInBlueprint + class Blueprint(ABC): """ @@ -585,6 +663,20 @@ def __init__( self.shared_state = shared_state self.frontend_task_config = shared_state.task_config + # We automatically call all mixins `init_mixin_config` methods available. + mixin_subclasses = BlueprintMixin.extract_unique_mixins(self.__class__) + for clazz in mixin_subclasses: + clazz.init_mixin_config(self, task_run, args, shared_state) + + @classmethod + def get_required_qualifications( + cls, args: DictConfig, shared_state: "SharedTaskState" + ): + quals = [] + for clazz in BlueprintMixin.extract_unique_mixins(cls): + quals += clazz.get_mixin_qualifications(args, shared_state) + return quals + @classmethod def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): """ @@ -592,6 +684,10 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): fail if a task launched with these arguments would not work """ + # We automatically call all mixins `assert_task_args` methods available. + mixin_subclasses = BlueprintMixin.extract_unique_mixins(cls) + for clazz in mixin_subclasses: + clazz.assert_mixin_args(args, shared_state) return def get_frontend_args(self) -> Dict[str, Any]: diff --git a/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py b/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py index 44f00c821..18be4e65e 100644 --- a/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py +++ b/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py @@ -52,12 +52,12 @@ @dataclass -class SharedStaticTaskState(SharedTaskState, OnboardingSharedState): +class SharedStaticTaskState(OnboardingRequired.SharedStateMixin, SharedTaskState): static_task_data: Iterable[Any] = field(default_factory=list) @dataclass -class StaticBlueprintArgs(BlueprintArgs): +class StaticBlueprintArgs(OnboardingRequired.ArgsMixin, BlueprintArgs): _blueprint_type: str = BLUEPRINT_TYPE _group: str = field( default="StaticBlueprint", @@ -92,7 +92,7 @@ class StaticBlueprintArgs(BlueprintArgs): ) -class StaticBlueprint(Blueprint, OnboardingRequired): +class StaticBlueprint(OnboardingRequired, Blueprint): """ Abstract blueprint for a task that runs without any extensive backend. These are generally one-off tasks sending data to the frontend and then @@ -114,7 +114,6 @@ def __init__( shared_state: "SharedStaticTaskState", ): super().__init__(task_run, args, shared_state) - self.init_onboarding_config(task_run, args, shared_state) # Originally just a list of dicts, but can also be a generator of dicts self._initialization_data_dicts: Iterable[Dict[str, Any]] = [] @@ -152,6 +151,8 @@ def __init__( @classmethod def assert_task_args(cls, args: DictConfig, shared_state: "SharedStaticTaskState"): """Ensure that the data can be properly loaded""" + super().assert_task_args(args, shared_state) + blue_args = args.blueprint if blue_args.get("data_csv", None) is not None: csv_file = os.path.expanduser(blue_args.data_csv) diff --git a/mephisto/abstractions/blueprints/mixins/onboarding_required.py b/mephisto/abstractions/blueprints/mixins/onboarding_required.py index 8498c1468..88ac0fed4 100644 --- a/mephisto/abstractions/blueprints/mixins/onboarding_required.py +++ b/mephisto/abstractions/blueprints/mixins/onboarding_required.py @@ -53,6 +53,9 @@ class OnboardingRequired(BlueprintMixin): Compositional class for blueprints that may have an onboarding step """ + ArgsMixin = OnboardingRequiredArgs + SharedStateMixin = OnboardingSharedState + def init_mixin_config( self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" ) -> None: @@ -60,7 +63,7 @@ def init_mixin_config( self.init_onboarding_config(task_run, args, shared_state) @classmethod - def assert_task_args( + def assert_mixin_args( cls, args: "DictConfig", shared_state: "SharedTaskState" ) -> None: """Method to validate the incoming args and throw if something won't work""" diff --git a/mephisto/abstractions/blueprints/mixins/screen_task_required.py b/mephisto/abstractions/blueprints/mixins/screen_task_required.py index b5c4f761a..cee7c0ff7 100644 --- a/mephisto/abstractions/blueprints/mixins/screen_task_required.py +++ b/mephisto/abstractions/blueprints/mixins/screen_task_required.py @@ -81,6 +81,9 @@ class ScreenTaskRequired(BlueprintMixin): qualify workers who have never attempted the task before """ + ArgsMixin = ScreenTaskRequiredArgs + SharedStateMixin = ScreenTaskSharedState + def init_mixin_config( self, task_run: "TaskRun", @@ -113,7 +116,7 @@ def init_screening_config( find_or_create_qualification(task_run.db, self.failed_qualification_name) @classmethod - def assert_task_args(cls, args: "DictConfig", shared_state: "SharedTaskState"): + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState"): use_screening_task = args.blueprint.get("use_screening_task", False) if not use_screening_task: return diff --git a/mephisto/abstractions/blueprints/mock/mock_blueprint.py b/mephisto/abstractions/blueprints/mock/mock_blueprint.py index b69bfae2f..ae85c6c29 100644 --- a/mephisto/abstractions/blueprints/mock/mock_blueprint.py +++ b/mephisto/abstractions/blueprints/mock/mock_blueprint.py @@ -90,9 +90,6 @@ def __init__( self, task_run: "TaskRun", args: "DictConfig", shared_state: "MockSharedState" ): super().__init__(task_run, args, shared_state) - # TODO these can be done with self.mro() and using the mixin variant - self.init_onboarding_config(task_run, args, shared_state) - self.init_screening_config(task_run, args, shared_state) def get_initialization_data(self) -> Iterable[InitializationData]: """ diff --git a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py index 2445dff7b..57c7a66a5 100644 --- a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py +++ b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py @@ -58,7 +58,7 @@ @dataclass -class SharedParlAITaskState(SharedTaskState, OnboardingSharedState): +class SharedParlAITaskState(OnboardingRequired.SharedStateMixin, SharedTaskState): frontend_task_opts: Dict[str, Any] = field(default_factory=dict) world_opt: Dict[str, Any] = field(default_factory=dict) onboarding_world_opt: Dict[str, Any] = field(default_factory=dict) @@ -66,7 +66,7 @@ class SharedParlAITaskState(SharedTaskState, OnboardingSharedState): @dataclass -class ParlAIChatBlueprintArgs(BlueprintArgs): +class ParlAIChatBlueprintArgs(OnboardingRequired.ArgsMixin, BlueprintArgs): _blueprint_type: str = BLUEPRINT_TYPE _group: str = field( default="ParlAIChatBlueprint", @@ -130,7 +130,7 @@ class ParlAIChatBlueprintArgs(BlueprintArgs): @register_mephisto_abstraction() -class ParlAIChatBlueprint(Blueprint, OnboardingRequired): +class ParlAIChatBlueprint(OnboardingRequired, Blueprint): """Blueprint for a task that runs a parlai chat""" AgentStateClass: ClassVar[Type["AgentState"]] = ParlAIChatAgentState @@ -154,7 +154,6 @@ def __init__( ): super().__init__(task_run, args, shared_state) self._initialization_data_dicts: List[Dict[str, Any]] = [] - self.init_onboarding_config(task_run, args, shared_state) if args.blueprint.get("context_csv", None) is not None: csv_file = os.path.expanduser(args.blueprint.context_csv) diff --git a/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py b/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py index cb766d238..03f6773d9 100644 --- a/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py +++ b/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py @@ -107,6 +107,7 @@ def __init__( @classmethod def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): """Ensure that the data can be properly loaded""" + Blueprint.assert_task_args(args, shared_state) blue_args = args.blueprint if isinstance(shared_state.static_task_data, types.GeneratorType): raise AssertionError("You can't launch an HTML static task on a generator") diff --git a/test/abstractions/blueprints/test_mixin_core.py b/test/abstractions/blueprints/test_mixin_core.py new file mode 100644 index 000000000..f31a854c2 --- /dev/null +++ b/test/abstractions/blueprints/test_mixin_core.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +import tempfile +import os +import shutil + +from omegaconf import OmegaConf +from dataclasses import dataclass + +from mephisto.data_model.task_run import TaskRun +from mephisto.abstractions.databases.local_database import LocalMephistoDB +from mephisto.abstractions.blueprint import ( + Blueprint, + BlueprintMixin, + BlueprintArgs, + SharedTaskState, +) +from mephisto.abstractions.test.utils import get_test_task_run +from mephisto.abstractions.architects.mock_architect import ( + MockArchitect, + MockArchitectArgs, +) +from mephisto.operations.hydra_config import MephistoConfig +from mephisto.abstractions.providers.mock.mock_provider import MockProviderArgs +from mephisto.abstractions.blueprints.mock.mock_blueprint import MockBlueprintArgs +from mephisto.data_model.task_config import TaskConfigArgs + +from typing import List, Dict, Optional, Any + + +class BrokenMixin(BlueprintMixin): + """Mixin that fails to define ArgsMixin or SharedStateMixin""" + + def init_mixin_config( + self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + return + + @classmethod + def assert_mixin_args( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + return + + @classmethod + def get_mixin_qualifications( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> List[Dict[str, Any]]: + return [] + + +@dataclass +class ArgsMixin1: + arg1: int = 0 + + +@dataclass +class StateMixin1: + arg1: int = 0 + + +class MockBlueprintMixin1(BlueprintMixin): + MOCK_QUAL_NAME = "mock_mixin_one" + ArgsMixin = ArgsMixin1 + SharedStateMixin = StateMixin1 + + def init_mixin_config( + self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + if hasattr(self, "mixin_init_calls"): + self.mixin_init_calls += 1 + else: + self.mixin_init_calls = 1 + + @classmethod + def assert_mixin_args( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + assert args.blueprint.arg1 == 0, "Was not the default value of arg1" + + @classmethod + def get_mixin_qualifications( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> List[Dict[str, Any]]: + return [{"qual_name": cls.MOCK_QUAL_NAME}] + + +@dataclass +class ArgsMixin2: + arg2: int = 0 + + +@dataclass +class StateMixin2: + arg2: int = 0 + + +class MockBlueprintMixin2(BlueprintMixin): + MOCK_QUAL_NAME = "mock_mixin_two" + ArgsMixin = ArgsMixin2 + SharedStateMixin = StateMixin2 + + def init_mixin_config( + self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + if hasattr(self, "mixin_init_calls"): + self.mixin_init_calls += 1 + else: + self.mixin_init_calls = 1 + + @classmethod + def assert_mixin_args( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + assert args.blueprint.arg2 == 0, "Was not the default value of arg2" + + @classmethod + def get_mixin_qualifications( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> List[Dict[str, Any]]: + return [{"qual_name": cls.MOCK_QUAL_NAME}] + + +@dataclass +class ComposedArgsMixin(ArgsMixin1, ArgsMixin2): + pass + + +@dataclass +class ComposedStateMixin(StateMixin1, StateMixin2): + pass + + +class ComposedMixin(MockBlueprintMixin1, MockBlueprintMixin2): + MOCK_QUAL_NAME = "mock_mixin_mixed" + ArgsMixin = ComposedArgsMixin + SharedStateMixin = ComposedStateMixin + + @classmethod + def assert_mixin_args( + cls, args: "DictConfig", shared_state: "SharedTaskState" + ) -> None: + MockBlueprintMixin1.assert_mixin_args(args, shared_state) + MockBlueprintMixin2.assert_mixin_args(args, shared_state) + + +class TestBlueprintMixinCore(unittest.TestCase): + """Test the functionality underlying blueprint mixins that allow them to work""" + + def get_structured_config(self, blueprint_args): + config = MephistoConfig( + blueprint=blueprint_args, + provider=MockProviderArgs(requester_name="mock_requester"), + architect=MockArchitectArgs(should_run_server=False), + task=TaskConfigArgs( + task_title="title", + task_description="This is a description", + task_reward="0.3", + task_tags="1,2,3", + maximum_units_per_worker=2, + allowed_concurrent=1, + task_name="max-unit-test", + ), + ) + return OmegaConf.structured(config) + + def setUp(self) -> None: + self.data_dir = tempfile.mkdtemp() + database_path = os.path.join(self.data_dir, "mephisto.db") + self.db = LocalMephistoDB(database_path) + self.task_run = TaskRun.get(self.db, get_test_task_run(self.db)) + + def tearDown(self) -> None: + self.db.shutdown() + shutil.rmtree(self.data_dir) + + def test_broken_mixin(self): + class TestBlueprint(BrokenMixin, Blueprint): + def get_initialization_data(self): + return [] + + args = TestBlueprint.ArgsClass() + shared_state = TestBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + with self.assertRaises( + TypeError, msg="Mixin classes not defined should raise type error" + ): + blueprint = TestBlueprint(self.task_run, cfg, shared_state) + + with self.assertRaises( + TypeError, msg="Undefined mixin classes should fail here too" + ): + + @BrokenMixin.mixin_args_and_state + class TestBlueprint(BrokenMixin, Blueprint): + def get_initialization_data(self): + return [] + + def test_working_mixin(self): + class TestBlueprint(MockBlueprintMixin1, Blueprint): + def get_initialization_data(self): + return [] + + args = TestBlueprint.ArgsClass() + shared_state = TestBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + with self.assertRaises(Exception, msg="Class should not have correct args"): + TestBlueprint.assert_task_args(cfg, shared_state) + + # Working mixin by manually creating classes + @dataclass + class TestArgs(ArgsMixin1, BlueprintArgs): + pass + + @dataclass + class TestState(StateMixin1, SharedTaskState): + pass + + class TestBlueprint(MockBlueprintMixin1, Blueprint): + ArgsClass = TestArgs + SharedStateClass = TestState + + def get_initialization_data(self): + return [] + + args = TestBlueprint.ArgsClass() + shared_state = TestBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + TestBlueprint.assert_task_args(cfg, shared_state) + blueprint = TestBlueprint(self.task_run, cfg, shared_state) + self.assertEqual( + blueprint.mixin_init_calls, 1, "More than one mixin init call!" + ) + + # Working mixin using the decorator + @MockBlueprintMixin1.mixin_args_and_state + class TestBlueprint(MockBlueprintMixin1, Blueprint): + def get_initialization_data(self): + return [] + + args = TestBlueprint.ArgsClass() + shared_state = TestBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + TestBlueprint.assert_task_args(cfg, shared_state) + blueprint = TestBlueprint(self.task_run, cfg, shared_state) + self.assertEqual( + blueprint.mixin_init_calls, 1, "More than one mixin init call!" + ) + + def test_mixin_multi_inheritence(self): + @MockBlueprintMixin1.mixin_args_and_state + @MockBlueprintMixin2.mixin_args_and_state + class DoubleMixinBlueprint(MockBlueprintMixin1, MockBlueprintMixin2, Blueprint): + def get_initialization_data(self): + return [] + + args = DoubleMixinBlueprint.ArgsClass() + shared_state = DoubleMixinBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + DoubleMixinBlueprint.assert_task_args(cfg, shared_state) + blueprint = DoubleMixinBlueprint(self.task_run, cfg, shared_state) + self.assertEqual(blueprint.mixin_init_calls, 2, "Should have 2 mixin calls") + + # Ensure qualifications are correct + required_quals = DoubleMixinBlueprint.get_required_qualifications( + args, shared_state + ) + self.assertEqual( + len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2 + ) + qual_names = [q["qual_name"] for q in required_quals] + self.assertIn(MockBlueprintMixin1.MOCK_QUAL_NAME, qual_names) + self.assertIn(MockBlueprintMixin2.MOCK_QUAL_NAME, qual_names) + + # Check functionality of important helpers + self.assertEqual( + len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2 + ) + + # Ensure failures work for each of the arg failures + shared_state = DoubleMixinBlueprint.SharedStateClass() + args = DoubleMixinBlueprint.ArgsClass(arg1=2) + cfg = self.get_structured_config(args) + with self.assertRaises(AssertionError, msg="Should have called both asserts"): + DoubleMixinBlueprint.assert_task_args(cfg, shared_state) + args = DoubleMixinBlueprint.ArgsClass(arg2=2) + print(args) + cfg = self.get_structured_config(args) + with self.assertRaises(AssertionError, msg="Should have called both asserts"): + DoubleMixinBlueprint.assert_task_args(cfg, shared_state) + + def test_composed_mixin_inheritence(self): + @ComposedMixin.mixin_args_and_state + class ComposedBlueprint(ComposedMixin, MockBlueprintMixin1, Blueprint): + def get_initialization_data(self): + return [] + + args = ComposedBlueprint.ArgsClass() + shared_state = ComposedBlueprint.SharedStateClass() + cfg = self.get_structured_config(args) + ComposedBlueprint.assert_task_args(cfg, shared_state) + blueprint = ComposedBlueprint(self.task_run, cfg, shared_state) + self.assertEqual(blueprint.mixin_init_calls, 1, "Should have 1 mixin call") + + # Ensure qualifications are correct + required_quals = ComposedBlueprint.get_required_qualifications( + args, shared_state + ) + self.assertEqual( + len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1 + ) + qual_names = [q["qual_name"] for q in required_quals] + self.assertIn(ComposedBlueprint.MOCK_QUAL_NAME, qual_names) + + # Check functionality of important helpers + self.assertEqual( + len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1 + ) + + # Ensure failures work for each of the arg failures + shared_state = ComposedBlueprint.SharedStateClass() + args = ComposedBlueprint.ArgsClass(arg1=2) + cfg = self.get_structured_config(args) + with self.assertRaises(AssertionError, msg="Should have called both asserts"): + ComposedBlueprint.assert_task_args(cfg, shared_state) + args = ComposedBlueprint.ArgsClass(arg2=2) + cfg = self.get_structured_config(args) + with self.assertRaises(AssertionError, msg="Should have called both asserts"): + ComposedBlueprint.assert_task_args(cfg, shared_state) + + +if __name__ == "__main__": + unittest.main()