Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOM-55317] Add identity to task execution metadata #2

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,18 @@ def setup_execution(
task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version),
)

metadata = {
"flyte-execution-project": exe_project,
"flyte-execution-domain": exe_domain,
"flyte-execution-launchplan": exe_lp,
"flyte-execution-workflow": exe_wf,
"flyte-execution-name": exe_name,
}
try:
file_access = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
raw_output_prefix=raw_output_data_prefix,
execution_metadata=metadata,
)
except TypeError: # would be thrown from DataPersistencePlugins.find_plugin
logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
Expand Down
20 changes: 20 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,24 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
return GCSConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class GenericPersistenceConfig(object):
"""
Data storage configuration that applies across any provider.
"""

attach_execution_metadata: bool = True

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig:
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(
kwargs, "attach_execution_metadata", _internal.Persistence.ATTACH_EXECUTION_METADATA.read(config_file)
)
return GenericPersistenceConfig(**kwargs)


@dataclass(init=True, repr=True, eq=True, frozen=True)
class AzureBlobStorageConfig(object):
"""
Expand Down Expand Up @@ -631,6 +649,7 @@ class DataConfig(object):
s3: S3Config = S3Config()
gcs: GCSConfig = GCSConfig()
azure: AzureBlobStorageConfig = AzureBlobStorageConfig()
generic: GenericPersistenceConfig = GenericPersistenceConfig()

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
Expand All @@ -639,6 +658,7 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig:
azure=AzureBlobStorageConfig.auto(config_file),
s3=S3Config.auto(config_file),
gcs=GCSConfig.auto(config_file),
generic=GenericPersistenceConfig.auto(config_file),
)


Expand Down
5 changes: 5 additions & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def get_specified_images(cfg: typing.Optional[ConfigFile]) -> typing.Dict[str, s
return cfg.yaml_config.get("images", images)


class Persistence(object):
SECTION = "persistence"
ATTACH_EXECUTION_METADATA = ConfigEntry(LegacyConfigEntry(SECTION, "attach_execution_metadata", bool))


class AWS(object):
SECTION = "aws"
S3_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "endpoint"), YamlConfigEntry("storage.connection.endpoint"))
Expand Down
10 changes: 10 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
local_sandbox_dir: Union[str, os.PathLike],
raw_output_prefix: str,
data_config: typing.Optional[DataConfig] = None,
execution_metadata: typing.Optional[dict] = None,
):
"""
Args:
Expand All @@ -148,6 +149,11 @@ def __init__(
self._local = fsspec.filesystem(None)

self._data_config = data_config if data_config else DataConfig.auto()

if self.data_config.generic.attach_execution_metadata:
self._execution_metadata = execution_metadata
else:
self._execution_metadata = None
self._default_protocol = get_protocol(str(raw_output_prefix))
self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol))
if os.name == "nt" and raw_output_prefix.startswith("file://"):
Expand Down Expand Up @@ -308,6 +314,10 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs):
self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True
)
from_path, to_path = self.recursive_paths(from_path, to_path)
if self._execution_metadata:
if "metadata" not in kwargs:
kwargs["metadata"] = {}
kwargs["metadata"].update(self._execution_metadata)
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
if isinstance(dst, (str, pathlib.Path)):
return dst
Expand Down
3 changes: 2 additions & 1 deletion flytekit/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def short_string(self):
:rtype: Text
"""
literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip()
return f"<FlyteLiteral {literal_str}>"
type_str = type(self).__name__
return f"<FlyteLiteral({type_str}) {literal_str}>"

def verbose_string(self):
"""
Expand Down
3 changes: 3 additions & 0 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ class Identity(_common.FlyteIdlEntity):
iam_role: Optional[str] = None
k8s_service_account: Optional[str] = None
oauth2_client: Optional[OAuth2Client] = None
execution_identity: Optional[str] = None

def to_flyte_idl(self) -> _sec.Identity:
return _sec.Identity(
iam_role=self.iam_role if self.iam_role else None,
k8s_service_account=self.k8s_service_account if self.k8s_service_account else None,
oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None,
execution_identity=self.execution_identity if self.execution_identity else None,
)

@classmethod
Expand All @@ -108,6 +110,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity":
oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client)
if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize()
else None,
execution_identity=pb2_object.execution_identity if pb2_object.execution_identity else None,
)


Expand Down
9 changes: 9 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def __init__(
annotations,
k8s_service_account,
environment_variables,
identity,
):
"""
Runtime task execution metadata.
Expand All @@ -539,13 +540,15 @@ def __init__(
:param dict[str, str] annotations: Annotations to use for the execution of this task.
:param Text k8s_service_account: Service account to use for execution of this task.
:param dict[str, str] environment_variables: Environment variables for this task.
:param flytekit.models.security.Identity identity: Identity of user executing this task
"""
self._task_execution_id = task_execution_id
self._namespace = namespace
self._labels = labels
self._annotations = annotations
self._k8s_service_account = k8s_service_account
self._environment_variables = environment_variables
self._identity = identity

@property
def task_execution_id(self):
Expand All @@ -571,6 +574,10 @@ def k8s_service_account(self):
def environment_variables(self):
return self._environment_variables

@property
def identity(self):
return self._identity

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata
Expand All @@ -584,6 +591,7 @@ def to_flyte_idl(self):
environment_variables={k: v for k, v in self.environment_variables.items()}
if self.labels is not None
else None,
identity=self.identity.to_flyte_idl() if self.identity else None,
)
return task_execution_metadata

Expand All @@ -604,6 +612,7 @@ def from_flyte_idl(cls, pb2_object):
environment_variables={k: v for k, v in pb2_object.environment_variables.items()}
if pb2_object.environment_variables is not None
else None,
identity=_sec.Identity.from_flyte_idl(pb2_object.identity) if pb2_object.identity else None,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0,<7.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.11.0b1",
"flyteidl>=1.12.0",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def wf(i: int, j: int):

# without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required,
# which is incorrect
with pytest.raises(FlyteAssertion, match="Missing input `i` type `<FlyteLiteral simple: INTEGER>`"):
with pytest.raises(FlyteAssertion, match=r"Missing input `i` type `<FlyteLiteral\(LiteralType\) simple: INTEGER>`"):
create_and_link_node_from_remote(ctx, lp)

# Even if j is not provided it will default
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]:
match=re.escape(
"Error encountered while executing 'wf2':\n"
f" Failed to convert inputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.t2':\n"
' Cannot convert from <FlyteLiteral scalar { union { value { scalar { primitive { string_value: "2" } } } '
' Cannot convert from <FlyteLiteral(Literal) scalar { union { value { scalar { primitive { string_value: "2" } } } '
'type { simple: STRING structure { tag: "str" } } } }> to typing.Union[float, dict] (using tag str)'
),
):
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WorkflowExecutionIdentifier,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Identity
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -159,6 +160,7 @@ def simple_task(i: int):
annotations={"annotation_key": "annotation_val"},
k8s_service_account="k8s service account",
environment_variables={"env_var_key": "env_var_val"},
identity=Identity(execution_identity="task executor"),
)


Expand Down
7 changes: 7 additions & 0 deletions tests/flytekit/unit/models/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,10 @@ def test_auth_role_empty():
x = obj.to_flyte_idl()
y = _common.AuthRole.from_flyte_idl(x)
assert y == obj


def test_short_string_raw_output_data_config():
""""""
obj = _common.RawOutputDataConfig("s3://bucket")
assert "FlyteLiteral(RawOutputDataConfig)" in obj.short_string()
assert "FlyteLiteral(RawOutputDataConfig)" in repr(obj)
Loading