Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): Improve task module extraction logic #2290

Merged
merged 7 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,11 @@
if mod_name == "__main__":
if hasattr(f, "task_function"):
f = f.task_function
# If the module is __main__, we need to find the actual module name based on the file path
inspect_file = inspect.getfile(f) # type: ignore
return name, "", name, os.path.abspath(inspect_file)
file_name, _ = os.path.splitext(os.path.basename(inspect_file))
mod_name = get_full_module_path(f, file_name) # type: ignore
return name, mod_name, name, os.path.abspath(inspect_file)

Check warning on line 337 in flytekit/core/tracker.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/tracker.py#L335-L337

Added lines #L335 - L337 were not covered by tests

mod_name = get_full_module_path(mod, mod_name)
return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod))
Expand Down
18 changes: 16 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
Expand Down Expand Up @@ -82,7 +83,7 @@
from flytekit.remote.remote_fs import get_flyte_fs
from flytekit.tools.fast_registration import fast_package
from flytekit.tools.interactive import ipython_check
from flytekit.tools.script_mode import compress_scripts, hash_file
from flytekit.tools.script_mode import _find_project_root, compress_scripts, hash_file
from flytekit.tools.translator import (
FlyteControlPlaneEntity,
FlyteLocalEntity,
Expand Down Expand Up @@ -778,7 +779,10 @@
return ident

def register_task(
self, entity: PythonTask, serialization_settings: SerializationSettings, version: typing.Optional[str] = None
self,
entity: PythonTask,
serialization_settings: typing.Optional[SerializationSettings] = None,
version: typing.Optional[str] = None,
) -> FlyteTask:
"""
Register a qualified task (PythonTask) with Remote
Expand All @@ -789,6 +793,16 @@
:param version: version that will be used to register. If not specified will default to using the serialization settings default
:return:
"""
# Create a default serialization settings object if not provided
# It makes registration easier for the user
if serialization_settings is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment here explaining when this situation is relevant? and maybe also a comment in tracker.py explaining when the second element in the return tuple (the thing that used to be "") is relevant?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

_, _, _, module_file = extract_task_module(entity)
project_root = _find_project_root(module_file)
serialization_settings = SerializationSettings(

Check warning on line 801 in flytekit/remote/remote.py

View check run for this annotation

Codecov / codecov/patch

flytekit/remote/remote.py#L799-L801

Added lines #L799 - L801 were not covered by tests
image_config=ImageConfig.auto_default_image(),
source_root=project_root,
)

ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
ft = self.fetch_task(
ident.project,
Expand Down
4 changes: 4 additions & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ def my_python_task(a: str) -> int:

# Check if the serialized python task has no mainApplicaitonFile field set by default.
assert serialized_spec.template.custom is None

remote.register_task(my_python_task, version="v1")
serialized_spec = mock_client.create_task.call_args.kwargs["task_spec"]
assert serialized_spec.template.custom is None
Loading