From e6c083f9cc0e85a9ec3c6767d4459df537aa235a Mon Sep 17 00:00:00 2001 From: Rishi Ravikumar <38955457+RRK1000@users.noreply.github.com> Date: Sun, 16 Jun 2024 05:12:42 -0400 Subject: [PATCH] fix: include remote file paths with special characters (#2478) Signed-off-by: Rishi Ravikumar Signed-off-by: bugra.gedik --- flytekit/types/file/file.py | 3 ++- tests/flytekit/unit/core/test_flyte_file.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 5720753e6f..8561ec0157 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -6,6 +6,7 @@ import typing from contextlib import contextmanager from dataclasses import dataclass, field +from urllib.parse import unquote from dataclasses_json import config from marshmallow import fields @@ -468,7 +469,7 @@ def to_literal( remote_path = ctx.file_access.put_data(source_path, remote_path, is_multipart=False, **headers) else: remote_path = ctx.file_access.put_raw_data(source_path, **headers) - return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path))))) # If not uploading, then we can only take the original source path as the uri. else: return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path))) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 6e055ca399..33a796b875 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -570,6 +570,24 @@ def wf(path: str) -> os.PathLike: assert flyte_tmp_dir in wf(path="s3://somewhere").path +def test_flyte_file_name_with_special_chars(): + temp_dir = tempfile.TemporaryDirectory() + file_path = os.path.join(temp_dir.name, "foo bar") + try: + with open(file_path, "w") as tmp: + tmp.write("hello world") + + @task + def get_file_path(f: FlyteFile) -> FlyteFile: + return f.path + + @workflow + def wf(f: FlyteFile) -> FlyteFile: + return get_file_path(f=f) + + wf(f=file_path) + finally: + temp_dir.cleanup() def test_flyte_file_annotated_hashmethod(local_dummy_file): def calc_hash(ff: FlyteFile) -> str: