diff --git a/flytekit/clis/sdk_in_container/fetch.py b/flytekit/clis/sdk_in_container/fetch.py index 580b0fcaac..8c83b5d548 100644 --- a/flytekit/clis/sdk_in_container/fetch.py +++ b/flytekit/clis/sdk_in_container/fetch.py @@ -1,63 +1,17 @@ -import os -import pathlib import typing import rich_click as click -from google.protobuf.json_format import MessageToJson from rich import print from rich.panel import Panel from rich.pretty import Pretty -from flytekit import BlobType, FlyteContext, Literal +from flytekit import Literal from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context from flytekit.core.type_engine import LiteralsResolver -from flytekit.interaction.rich_utils import RichCallback from flytekit.interaction.string_literals import literal_map_string_repr, literal_string_repr from flytekit.remote import FlyteRemote -def download_literal(var: str, data: Literal, download_to: typing.Optional[pathlib.Path] = None): - """ - Download a single literal to a file, if it is a blob or structured dataset. 
- """ - if data is None: - print(f"Skipping {var} as it is None.") - return - if data.scalar: - if data.scalar and (data.scalar.blob or data.scalar.structured_dataset): - uri = data.scalar.blob.uri if data.scalar.blob else data.scalar.structured_dataset.uri - if uri is None: - print("No data to download.") - return - is_multipart = False - if data.scalar.blob: - is_multipart = data.scalar.blob.metadata.type.dimensionality == BlobType.BlobDimensionality.MULTIPART - elif data.scalar.structured_dataset: - is_multipart = True - FlyteContext.current_context().file_access.get_data( - uri, str(download_to / var) + os.sep, is_multipart=is_multipart, callback=RichCallback() - ) - elif data.scalar.union is not None: - download_literal(var, data.scalar.union.value, download_to) - elif data.scalar.generic is not None: - with open(download_to / f"{var}.json", "w") as f: - f.write(MessageToJson(data.scalar.generic)) - else: - print( - f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset," - f" or generic type.[/dim]" - ) - return - elif data.collection: - for i, v in enumerate(data.collection.literals): - download_literal(f"{i}", v, download_to / var) - elif data.map: - download_to = pathlib.Path(download_to) - for k, v in data.map.literals.items(): - download_literal(f"{k}", v, download_to / var) - print(f"Downloaded f{var} to {download_to}") - - @click.command("fetch") @click.option( "--recursive", @@ -91,11 +45,4 @@ def fetch(ctx: click.Context, recursive: bool, flyte_data_uri: str, download_to: panel = Panel(pretty) print(panel) if download_to: - download_to = pathlib.Path(download_to) - if isinstance(data, Literal): - download_literal("data", data, download_to) - else: - if not recursive: - raise click.UsageError("Please specify --recursive to download all variables in a literal map.") - for var, literal in data.literals.items(): - download_literal(var, literal, download_to) + remote.download(data, download_to, recursive=recursive) 
diff --git a/flytekit/remote/data.py b/flytekit/remote/data.py
new file mode 100644
index 0000000000..84fcff1420
--- /dev/null
+++ b/flytekit/remote/data.py
@@ -0,0 +1,68 @@
+import os
+import pathlib
+import typing
+
+from google.protobuf.json_format import MessageToJson
+from rich import print
+
+from flytekit import BlobType, Literal
+from flytekit.core.data_persistence import FileAccessProvider
+from flytekit.interaction.rich_utils import RichCallback
+from flytekit.interaction.string_literals import literal_string_repr
+
+
+def download_literal(
+    file_access: FileAccessProvider, var: str, data: Literal, download_to: typing.Optional[pathlib.Path] = None
+):
+    """
+    Download a single literal to the local filesystem.
+
+    Blobs and structured datasets are fetched with ``file_access.get_data`` into ``download_to / var``, generic
+    scalars are serialised with ``MessageToJson`` into ``download_to / <var>.json``, unions are unwrapped, and
+    collections and maps are walked recursively into a ``var`` sub-directory. Other scalars are skipped.
+
+    :param file_access: provider whose ``get_data`` fetches the remote blob / structured dataset
+    :param var: variable name, used as the file or directory name under ``download_to``
+    :param data: literal to download; ``None`` is skipped with a message
+    :param download_to: directory to download into; the current working directory is used when ``None``
+    """
+    if data is None:
+        print(f"Skipping {var} as it is None.")
+        return
+    # Normalise once so that str paths and the ``None`` default both support the ``/`` operator used below.
+    download_to = pathlib.Path.cwd() if download_to is None else pathlib.Path(download_to)
+    if data.scalar:
+        if data.scalar.blob or data.scalar.structured_dataset:
+            uri = data.scalar.blob.uri if data.scalar.blob else data.scalar.structured_dataset.uri
+            if not uri:
+                print("No data to download.")
+                return
+            # Structured datasets are always fetched as multipart; blobs only when declared MULTIPART.
+            is_multipart = True
+            if data.scalar.blob:
+                is_multipart = data.scalar.blob.metadata.type.dimensionality == BlobType.BlobDimensionality.MULTIPART
+            file_access.get_data(
+                uri, str(download_to / var) + os.sep, is_multipart=is_multipart, callback=RichCallback()
+            )
+        elif data.scalar.union is not None:
+            # The recursive call reports its own outcome; return so it is not reported a second time.
+            download_literal(file_access, var, data.scalar.union.value, download_to)
+            return
+        elif data.scalar.generic is not None:
+            # Nested collections / maps pass a sub-directory that may not exist yet.
+            download_to.mkdir(parents=True, exist_ok=True)
+            with open(download_to / f"{var}.json", "w") as f:
+                f.write(MessageToJson(data.scalar.generic))
+        else:
+            print(
+                f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset,"
+                f" or generic type.[/dim]"
+            )
+            return
+    elif data.collection:
+        for i, v in enumerate(data.collection.literals):
+            download_literal(file_access, f"{i}", v, download_to / var)
+    elif data.map:
+        for k, v in data.map.literals.items():
+            download_literal(file_access, f"{k}", v, download_to / var)
+    print(f"Downloaded {var} to {download_to}")
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index 213f96a2eb..618f8d9bda 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -19,6 +19,7 @@
 from dataclasses import asdict, dataclass
 from datetime import datetime, timedelta
 
+import click
 import fsspec
 import requests
 from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
@@ -66,6 +67,7 @@
 from flytekit.models.launch_plan import LaunchPlanState
 from flytekit.models.literals import Literal, LiteralMap
 from flytekit.remote.backfill import create_backfill_workflow
+from flytekit.remote.data import download_literal
 from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
 from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
 from flytekit.remote.interface import TypedInterface
@@ -2008,3 +2010,32 @@ def activate_launchplan(self, ident: Identifier):
         Given a launchplan, activate it, all previous versions are deactivated.
         """
         self.client.update_launch_plan(id=ident, state=LaunchPlanState.ACTIVE)
+
+    def download(
+        self, data: typing.Union[LiteralsResolver, Literal, LiteralMap], download_to: str, recursive: bool = True
+    ):
+        """
+        Download the data to the specified location. If the data is a LiteralsResolver or a LiteralMap and recursive
+        is specified, then all file-like objects will be recursively downloaded (e.g. FlyteFile/Dir (blob),
+        StructuredDataset etc).
+
+        Note: this uses your session's credentials to access the remote location. For sandbox, this should be
+        automatically configured, assuming you are running sandbox locally. For other environments, you will need to
+        configure your credentials appropriately.
+
+        :param data: data to be downloaded
+        :param download_to: location to download to (str) that should be a valid path
+        :param recursive: must be set when the data is a LiteralsResolver or LiteralMap, so that every variable in it
+            is downloaded
+        :raises click.UsageError: if data is a LiteralsResolver or LiteralMap and recursive is False
+        """
+        download_to = pathlib.Path(download_to)
+        if isinstance(data, Literal):
+            download_literal(self.file_access, "data", data, download_to)
+            return
+        if not recursive:
+            raise click.UsageError("Please specify --recursive to download all variables in a literal map.")
+        # Both LiteralsResolver and LiteralMap expose their name -> Literal dict as ``.literals``; iterate that
+        # rather than calling ``items()`` on the model object itself.
+        for var, literal in data.literals.items():
+            download_literal(self.file_access, var, literal, download_to)