Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blueprint Mixin cleanup and test #582

Merged
merged 1 commit into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion mephisto/abstractions/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,55 @@ def get_task_end(self) -> Optional[float]:
class BlueprintMixin(ABC):
"""
Base class for compositional mixins for blueprints

We expect mixins that subclass other mixins to handle subinitialization
work, such that only the highest class needs to be called.
"""

@property
@abstractmethod
def ArgsMixin(self) -> Any: # Should be a dataclass, to extend BlueprintArgs
pass

@property
@abstractmethod
def SharedStateMixin(
self,
) -> Any: # Also should be a dataclass, to extend SharedTaskState
pass

@staticmethod
def extract_unique_mixins(blueprint_class: ClassVar[Type["Blueprint"]]):
"""Return the unique mixin classes that are used in the given blueprint class"""
mixin_subclasses = [
clazz
for clazz in blueprint_class.mro()
if issubclass(clazz, BlueprintMixin)
]
# Remove magic created with `mixin_args_and_state`
while blueprint_class.__name__ == "MixedInBlueprint":
blueprint_class = mixin_subclasses.pop(0)
removed_locals = [
clazz
for clazz in mixin_subclasses
if "MixedInBlueprint" not in clazz.__name__
]
filtered_subclasses = set(
clazz
for clazz in removed_locals
if clazz != BlueprintMixin and clazz != blueprint_class
)
# we also want to make sure that we don't double-count extensions of mixins, so remove classes that other classes are subclasses of
def is_subclassed(clazz):
return True in [
issubclass(x, clazz) and x != clazz for x in filtered_subclasses
]

unique_subclasses = [
clazz for clazz in filtered_subclasses if not is_subclassed(clazz)
]
return unique_subclasses

@abstractmethod
def init_mixin_config(
self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"
Expand All @@ -545,7 +592,7 @@ def init_mixin_config(

@classmethod
@abstractmethod
def assert_task_args(
def assert_mixin_args(
cls, args: "DictConfig", shared_state: "SharedTaskState"
) -> None:
"""Method to validate the incoming args and throw if something won't work"""
Expand All @@ -559,6 +606,37 @@ def get_mixin_qualifications(
"""Method to provide any required qualifications to make this mixin function"""
raise NotImplementedError()

@classmethod
def mixin_args_and_state(mixin_cls: "BlueprintMixin", target_cls: "Blueprint"):
"""
Magic utility decorator that can be used to inject mixin configurations
(BlueprintArgs and SharedTaskState) without the user needing to define new
classes for these. Should only be used by tasks that aren't already specifying
new versions of these, which should just inherit otherwise.

Usage:
@register_mephisto_abstraction()
@ABlueprintMixin.mixin_args_and_state
class MyBlueprint(ABlueprintMixin, Blueprint):
pass
"""

@dataclass
class MixedInArgsClass(mixin_cls.ArgsMixin, target_cls.ArgsClass):
pass

@dataclass
class MixedInSharedStateClass(
mixin_cls.SharedStateMixin, target_cls.SharedStateClass
):
pass

class MixedInBlueprint(target_cls):
ArgsClass = MixedInArgsClass
SharedStateClass = MixedInSharedStateClass

return MixedInBlueprint


class Blueprint(ABC):
"""
Expand All @@ -585,13 +663,31 @@ def __init__(
self.shared_state = shared_state
self.frontend_task_config = shared_state.task_config

# We automatically call all mixins `init_mixin_config` methods available.
mixin_subclasses = BlueprintMixin.extract_unique_mixins(self.__class__)
for clazz in mixin_subclasses:
clazz.init_mixin_config(self, task_run, args, shared_state)

@classmethod
def get_required_qualifications(
cls, args: DictConfig, shared_state: "SharedTaskState"
):
quals = []
for clazz in BlueprintMixin.extract_unique_mixins(cls):
quals += clazz.get_mixin_qualifications(args, shared_state)
return quals

@classmethod
def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"):
"""
Assert that the provided arguments are valid. Should
fail if a task launched with these arguments would
not work
"""
# We automatically call all mixins `assert_task_args` methods available.
mixin_subclasses = BlueprintMixin.extract_unique_mixins(cls)
for clazz in mixin_subclasses:
clazz.assert_mixin_args(args, shared_state)
return

def get_frontend_args(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@


@dataclass
class SharedStaticTaskState(SharedTaskState, OnboardingSharedState):
class SharedStaticTaskState(OnboardingRequired.SharedStateMixin, SharedTaskState):
static_task_data: Iterable[Any] = field(default_factory=list)


@dataclass
class StaticBlueprintArgs(BlueprintArgs):
class StaticBlueprintArgs(OnboardingRequired.ArgsMixin, BlueprintArgs):
_blueprint_type: str = BLUEPRINT_TYPE
_group: str = field(
default="StaticBlueprint",
Expand Down Expand Up @@ -92,7 +92,7 @@ class StaticBlueprintArgs(BlueprintArgs):
)


class StaticBlueprint(Blueprint, OnboardingRequired):
class StaticBlueprint(OnboardingRequired, Blueprint):
"""
Abstract blueprint for a task that runs without any extensive backend.
These are generally one-off tasks sending data to the frontend and then
Expand All @@ -114,7 +114,6 @@ def __init__(
shared_state: "SharedStaticTaskState",
):
super().__init__(task_run, args, shared_state)
self.init_onboarding_config(task_run, args, shared_state)

# Originally just a list of dicts, but can also be a generator of dicts
self._initialization_data_dicts: Iterable[Dict[str, Any]] = []
Expand Down Expand Up @@ -152,6 +151,8 @@ def __init__(
@classmethod
def assert_task_args(cls, args: DictConfig, shared_state: "SharedStaticTaskState"):
"""Ensure that the data can be properly loaded"""
super().assert_task_args(args, shared_state)

blue_args = args.blueprint
if blue_args.get("data_csv", None) is not None:
csv_file = os.path.expanduser(blue_args.data_csv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ class OnboardingRequired(BlueprintMixin):
Compositional class for blueprints that may have an onboarding step
"""

ArgsMixin = OnboardingRequiredArgs
SharedStateMixin = OnboardingSharedState

def init_mixin_config(
self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"
) -> None:
"""Method to initialize any required attributes to make this mixin function"""
self.init_onboarding_config(task_run, args, shared_state)

@classmethod
def assert_task_args(
def assert_mixin_args(
cls, args: "DictConfig", shared_state: "SharedTaskState"
) -> None:
"""Method to validate the incoming args and throw if something won't work"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class ScreenTaskRequired(BlueprintMixin):
qualify workers who have never attempted the task before
"""

ArgsMixin = ScreenTaskRequiredArgs
SharedStateMixin = ScreenTaskSharedState

def init_mixin_config(
self,
task_run: "TaskRun",
Expand Down Expand Up @@ -113,7 +116,7 @@ def init_screening_config(
find_or_create_qualification(task_run.db, self.failed_qualification_name)

@classmethod
def assert_task_args(cls, args: "DictConfig", shared_state: "SharedTaskState"):
def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState"):
use_screening_task = args.blueprint.get("use_screening_task", False)
if not use_screening_task:
return
Expand Down
3 changes: 0 additions & 3 deletions mephisto/abstractions/blueprints/mock/mock_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ def __init__(
self, task_run: "TaskRun", args: "DictConfig", shared_state: "MockSharedState"
):
super().__init__(task_run, args, shared_state)
# TODO these can be done with self.mro() and using the mixin variant
self.init_onboarding_config(task_run, args, shared_state)
self.init_screening_config(task_run, args, shared_state)

def get_initialization_data(self) -> Iterable[InitializationData]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@


@dataclass
class SharedParlAITaskState(SharedTaskState, OnboardingSharedState):
class SharedParlAITaskState(OnboardingRequired.SharedStateMixin, SharedTaskState):
frontend_task_opts: Dict[str, Any] = field(default_factory=dict)
world_opt: Dict[str, Any] = field(default_factory=dict)
onboarding_world_opt: Dict[str, Any] = field(default_factory=dict)
world_module: Optional[Any] = None


@dataclass
class ParlAIChatBlueprintArgs(BlueprintArgs):
class ParlAIChatBlueprintArgs(OnboardingRequired.ArgsMixin, BlueprintArgs):
_blueprint_type: str = BLUEPRINT_TYPE
_group: str = field(
default="ParlAIChatBlueprint",
Expand Down Expand Up @@ -130,7 +130,7 @@ class ParlAIChatBlueprintArgs(BlueprintArgs):


@register_mephisto_abstraction()
class ParlAIChatBlueprint(Blueprint, OnboardingRequired):
class ParlAIChatBlueprint(OnboardingRequired, Blueprint):
"""Blueprint for a task that runs a parlai chat"""

AgentStateClass: ClassVar[Type["AgentState"]] = ParlAIChatAgentState
Expand All @@ -154,7 +154,6 @@ def __init__(
):
super().__init__(task_run, args, shared_state)
self._initialization_data_dicts: List[Dict[str, Any]] = []
self.init_onboarding_config(task_run, args, shared_state)

if args.blueprint.get("context_csv", None) is not None:
csv_file = os.path.expanduser(args.blueprint.context_csv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
@classmethod
def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"):
"""Ensure that the data can be properly loaded"""
Blueprint.assert_task_args(args, shared_state)
blue_args = args.blueprint
if isinstance(shared_state.static_task_data, types.GeneratorType):
raise AssertionError("You can't launch an HTML static task on a generator")
Expand Down
Loading