Skip to content

Commit

Permalink
Adds download command to remote. (#1946)
Browse files Browse the repository at this point in the history
  • Loading branch information
kumare3 authored Nov 9, 2023
1 parent c7c8289 commit 6998304
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 55 deletions.
57 changes: 2 additions & 55 deletions flytekit/clis/sdk_in_container/fetch.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,17 @@
import os
import pathlib
import typing

import rich_click as click
from google.protobuf.json_format import MessageToJson
from rich import print
from rich.panel import Panel
from rich.pretty import Pretty

from flytekit import BlobType, FlyteContext, Literal
from flytekit import Literal
from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context
from flytekit.core.type_engine import LiteralsResolver
from flytekit.interaction.rich_utils import RichCallback
from flytekit.interaction.string_literals import literal_map_string_repr, literal_string_repr
from flytekit.remote import FlyteRemote


def download_literal(var: str, data: Literal, download_to: typing.Optional[pathlib.Path] = None):
"""
Download a single literal to a file, if it is a blob or structured dataset.
"""
if data is None:
print(f"Skipping {var} as it is None.")
return
if data.scalar:
if data.scalar and (data.scalar.blob or data.scalar.structured_dataset):
uri = data.scalar.blob.uri if data.scalar.blob else data.scalar.structured_dataset.uri
if uri is None:
print("No data to download.")
return
is_multipart = False
if data.scalar.blob:
is_multipart = data.scalar.blob.metadata.type.dimensionality == BlobType.BlobDimensionality.MULTIPART
elif data.scalar.structured_dataset:
is_multipart = True
FlyteContext.current_context().file_access.get_data(
uri, str(download_to / var) + os.sep, is_multipart=is_multipart, callback=RichCallback()
)
elif data.scalar.union is not None:
download_literal(var, data.scalar.union.value, download_to)
elif data.scalar.generic is not None:
with open(download_to / f"{var}.json", "w") as f:
f.write(MessageToJson(data.scalar.generic))
else:
print(
f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset,"
f" or generic type.[/dim]"
)
return
elif data.collection:
for i, v in enumerate(data.collection.literals):
download_literal(f"{i}", v, download_to / var)
elif data.map:
download_to = pathlib.Path(download_to)
for k, v in data.map.literals.items():
download_literal(f"{k}", v, download_to / var)
print(f"Downloaded f{var} to {download_to}")


@click.command("fetch")
@click.option(
"--recursive",
Expand Down Expand Up @@ -91,11 +45,4 @@ def fetch(ctx: click.Context, recursive: bool, flyte_data_uri: str, download_to:
panel = Panel(pretty)
print(panel)
if download_to:
download_to = pathlib.Path(download_to)
if isinstance(data, Literal):
download_literal("data", data, download_to)
else:
if not recursive:
raise click.UsageError("Please specify --recursive to download all variables in a literal map.")
for var, literal in data.literals.items():
download_literal(var, literal, download_to)
remote.download(data, download_to, recursive=recursive)
55 changes: 55 additions & 0 deletions flytekit/remote/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import pathlib
import typing

from google.protobuf.json_format import MessageToJson
from rich import print

from flytekit import BlobType, Literal
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.interaction.rich_utils import RichCallback
from flytekit.interaction.string_literals import literal_string_repr


def download_literal(
file_access: FileAccessProvider, var: str, data: Literal, download_to: typing.Optional[pathlib.Path] = None
):
"""
Download a single literal to a file, if it is a blob or structured dataset.
"""
if data is None:
print(f"Skipping {var} as it is None.")
return
if data.scalar:
if data.scalar and (data.scalar.blob or data.scalar.structured_dataset):
uri = data.scalar.blob.uri if data.scalar.blob else data.scalar.structured_dataset.uri
if uri is None:
print("No data to download.")
return
is_multipart = False
if data.scalar.blob:
is_multipart = data.scalar.blob.metadata.type.dimensionality == BlobType.BlobDimensionality.MULTIPART
elif data.scalar.structured_dataset:
is_multipart = True
file_access.get_data(
uri, str(download_to / var) + os.sep, is_multipart=is_multipart, callback=RichCallback()
)
elif data.scalar.union is not None:
download_literal(file_access, var, data.scalar.union.value, download_to)
elif data.scalar.generic is not None:
with open(download_to / f"{var}.json", "w") as f:
f.write(MessageToJson(data.scalar.generic))
else:
print(
f"[dim]Skipping {var} val {literal_string_repr(data)} as it is not a blob, structured dataset,"
f" or generic type.[/dim]"
)
return
elif data.collection:
for i, v in enumerate(data.collection.literals):
download_literal(file_access, f"{i}", v, download_to / var)
elif data.map:
download_to = pathlib.Path(download_to)
for k, v in data.map.literals.items():
download_literal(file_access, f"{k}", v, download_to / var)
print(f"Downloaded f{var} to {download_to}")
31 changes: 31 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta

import click
import fsspec
import requests
from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
Expand Down Expand Up @@ -66,6 +67,7 @@
from flytekit.models.launch_plan import LaunchPlanState
from flytekit.models.literals import Literal, LiteralMap
from flytekit.remote.backfill import create_backfill_workflow
from flytekit.remote.data import download_literal
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from flytekit.remote.interface import TypedInterface
Expand Down Expand Up @@ -2008,3 +2010,32 @@ def activate_launchplan(self, ident: Identifier):
Given a launchplan, activate it, all previous versions are deactivated.
"""
self.client.update_launch_plan(id=ident, state=LaunchPlanState.ACTIVE)

def download(
self, data: typing.Union[LiteralsResolver, Literal, LiteralMap], download_to: str, recursive: bool = True
):
"""
Download the data to the specified location. If the data is a LiteralsResolver, LiteralMap and if recursive is
specified, then all file like objects will be recursively downloaded (e.g. FlyteFile/Dir (blob),
StructuredDataset etc).
Note: That it will use your sessions credentials to access the remote location. For sandbox, this should be
automatically configured, assuming you are running sandbox locally. For other environments, you will need to
configure your credentials appropriately.
:param data: data to be downloaded
:param download_to: location to download to (str) that should be a valid path
:param recursive: if the data is a LiteralsResolver or LiteralMap, then this flag will recursively download
"""
download_to = pathlib.Path(download_to)
if isinstance(data, Literal):
download_literal(self.file_access, "data", data, download_to)
else:
if not recursive:
raise click.UsageError("Please specify --recursive to download all variables in a literal map.")
if isinstance(data, LiteralsResolver):
lm = data.literals
else:
lm = data
for var, literal in lm.items():
download_literal(self.file_access, var, literal, download_to)

0 comments on commit 6998304

Please sign in to comment.