diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ff276da43e..ed20561ac3 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -272,10 +272,18 @@ def setup_execution( task_id=_identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version), ) + metadata = { + "flyte-execution-project": exe_project, + "flyte-execution-domain": exe_domain, + "flyte-execution-launchplan": exe_lp, + "flyte-execution-workflow": exe_wf, + "flyte-execution-name": exe_name, + } try: file_access = FileAccessProvider( local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=raw_output_data_prefix, + execution_metadata=metadata, ) except TypeError: # would be thrown from DataPersistencePlugins.find_plugin logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}") diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 5d7c975f32..c29bd71c88 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -596,6 +596,24 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: return GCSConfig(**kwargs) +@dataclass(init=True, repr=True, eq=True, frozen=True) +class GenericPersistenceConfig(object): + """ + Data storage configuration that applies across any provider. + """ + + attach_execution_metadata: bool = True + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists( + kwargs, "attach_execution_metadata", _internal.Persistence.ATTACH_EXECUTION_METADATA.read(config_file) + ) + return GenericPersistenceConfig(**kwargs) + + @dataclass(init=True, repr=True, eq=True, frozen=True) class AzureBlobStorageConfig(object): """ @@ -631,6 +649,7 @@ class DataConfig(object): s3: S3Config = S3Config() gcs: GCSConfig = GCSConfig() azure: AzureBlobStorageConfig = AzureBlobStorageConfig() + generic: GenericPersistenceConfig = GenericPersistenceConfig() @classmethod def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig: @@ -639,6 +658,7 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig: azure=AzureBlobStorageConfig.auto(config_file), s3=S3Config.auto(config_file), gcs=GCSConfig.auto(config_file), + generic=GenericPersistenceConfig.auto(config_file), ) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 2f28381782..c93e65e635 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -35,6 +35,11 @@ def get_specified_images(cfg: typing.Optional[ConfigFile]) -> typing.Dict[str, s return cfg.yaml_config.get("images", images) +class Persistence(object): + SECTION = "persistence" + ATTACH_EXECUTION_METADATA = ConfigEntry(LegacyConfigEntry(SECTION, "attach_execution_metadata", bool)) + + class AWS(object): SECTION = "aws" S3_ENDPOINT = ConfigEntry(LegacyConfigEntry(SECTION, "endpoint"), YamlConfigEntry("storage.connection.endpoint")) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index f507e491b1..492362b819 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -132,6 +132,7 @@ def __init__( local_sandbox_dir: Union[str, os.PathLike], raw_output_prefix: str, data_config: typing.Optional[DataConfig] = None, + execution_metadata: typing.Optional[dict] = None, ): """ Args: @@ -148,6 +149,11 @@ def __init__( self._local = fsspec.filesystem(None) self._data_config = data_config if data_config else DataConfig.auto() + + if self.data_config.generic.attach_execution_metadata: + self._execution_metadata = execution_metadata + else: + self._execution_metadata = None self._default_protocol = get_protocol(str(raw_output_prefix)) self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol)) if os.name == "nt" and raw_output_prefix.startswith("file://"): @@ -308,6 +314,10 @@ def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True ) from_path, to_path = self.recursive_paths(from_path, to_path) + if self._execution_metadata: + if "metadata" not in kwargs: + kwargs["metadata"] = {} + kwargs["metadata"].update(self._execution_metadata) dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs) if isinstance(dst, (str, pathlib.Path)): return dst diff --git a/flytekit/models/common.py b/flytekit/models/common.py index f0a0bd3385..8fdd0837a8 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -61,7 +61,8 @@ def short_string(self): :rtype: Text """ literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip() - return f"" + type_str = type(self).__name__ + return f"" def verbose_string(self): """ diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 748b4d09eb..e210c910b7 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -92,12 +92,14 @@ class Identity(_common.FlyteIdlEntity): iam_role: Optional[str] = None k8s_service_account: Optional[str] = None oauth2_client: Optional[OAuth2Client] = None + execution_identity: Optional[str] = None def to_flyte_idl(self) -> _sec.Identity: return _sec.Identity( iam_role=self.iam_role if self.iam_role else None, k8s_service_account=self.k8s_service_account if self.k8s_service_account else None, oauth2_client=self.oauth2_client.to_flyte_idl() if self.oauth2_client else None, + execution_identity=self.execution_identity if self.execution_identity else None, ) @classmethod @@ -108,6 +110,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Identity) -> "Identity": oauth2_client=OAuth2Client.from_flyte_idl(pb2_object.oauth2_client) if pb2_object.oauth2_client and pb2_object.oauth2_client.ByteSize() else None, + execution_identity=pb2_object.execution_identity if pb2_object.execution_identity else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 5072f03757..0532b276e2 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -528,6 +528,7 @@ def __init__( annotations, k8s_service_account, environment_variables, + identity, ): """ Runtime task execution metadata. @@ -539,6 +540,7 @@ def __init__( :param dict[str, str] annotations: Annotations to use for the execution of this task. :param Text k8s_service_account: Service account to use for execution of this task. :param dict[str, str] environment_variables: Environment variables for this task. + :param flytekit.models.security.Identity identity: Identity of user executing this task """ self._task_execution_id = task_execution_id self._namespace = namespace @@ -546,6 +548,7 @@ def __init__( self._annotations = annotations self._k8s_service_account = k8s_service_account self._environment_variables = environment_variables + self._identity = identity @property def task_execution_id(self): @@ -571,6 +574,10 @@ def k8s_service_account(self): def environment_variables(self): return self._environment_variables + @property + def identity(self): + return self._identity + def to_flyte_idl(self): """ :rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata @@ -584,6 +591,7 @@ def to_flyte_idl(self): environment_variables={k: v for k, v in self.environment_variables.items()} if self.labels is not None else None, + identity=self.identity.to_flyte_idl() if self.identity else None, ) return task_execution_metadata @@ -604,6 +612,7 @@ def from_flyte_idl(cls, pb2_object): environment_variables={k: v for k, v in pb2_object.environment_variables.items()} if pb2_object.environment_variables is not None else None, + identity=_sec.Identity.from_flyte_idl(pb2_object.identity) if pb2_object.identity else None, ) diff --git a/pyproject.toml b/pyproject.toml index e3d311a73a..ad2fbb7d6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0,<7.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.11.0b1", + "flyteidl>=1.12.0", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 654bff0294..74f3db99e3 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -84,7 +84,7 @@ def wf(i: int, j: int): # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, # which is incorrect - with pytest.raises(FlyteAssertion, match="Missing input `i` type ``"): + with pytest.raises(FlyteAssertion, match=r"Missing input `i` type ``"): create_and_link_node_from_remote(ctx, lp) # Even if j is not provided it will default diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 3fa9579bd5..e3b0978565 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -1701,7 +1701,7 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]: match=re.escape( "Error encountered while executing 'wf2':\n" f" Failed to convert inputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.t2':\n" - ' Cannot convert from to typing.Union[float, dict] (using tag str)' ), ): diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 17db5c2788..3226313079 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -47,6 +47,7 @@ WorkflowExecutionIdentifier, ) from flytekit.models.literals import LiteralMap +from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable @@ -159,6 +160,7 @@ def simple_task(i: int): annotations={"annotation_key": "annotation_val"}, k8s_service_account="k8s service account", environment_variables={"env_var_key": "env_var_val"}, + identity=Identity(execution_identity="task executor"), ) diff --git a/tests/flytekit/unit/models/test_common.py b/tests/flytekit/unit/models/test_common.py index 48e8c0ba1f..b3754c16ff 100644 --- a/tests/flytekit/unit/models/test_common.py +++ b/tests/flytekit/unit/models/test_common.py @@ -103,3 +103,10 @@ def test_auth_role_empty(): x = obj.to_flyte_idl() y = _common.AuthRole.from_flyte_idl(x) assert y == obj + + +def test_short_string_raw_output_data_config(): + """""" + obj = _common.RawOutputDataConfig("s3://bucket") + assert "FlyteLiteral(RawOutputDataConfig)" in obj.short_string() + assert "FlyteLiteral(RawOutputDataConfig)" in repr(obj)