diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 5c8036d179..a6b401bff8 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -454,6 +454,39 @@ def join( f = fs.unstrip_protocol(f) return f + def generate_new_custom_path( + self, + fs: typing.Optional[fsspec.AbstractFileSystem] = None, + alt: typing.Optional[str] = None, + stem: typing.Optional[str] = None, + ) -> str: + """ + Generates a new path with the raw output prefix and a random string appended to it. + Optionally, you can provide an alternate prefix and a stem. If stem is provided, it + will be appended to the path instead of a random string. If alt is provided, it will + replace the first part of the output prefix, e.g. the S3 or GCS bucket. + + If wanting to write to a non-random prefix in a non-default S3 bucket, this can be + called with alt="my-alt-bucket" and stem="my-stem" to generate a path like + s3://my-alt-bucket/default-prefix-part/my-stem + + :param fs: The filesystem to use. If None, the context's raw output filesystem is used. + :param alt: An alternate first member of the prefix to use instead of the default. + :param stem: A stem to append to the path. + :return: The new path. + """ + fs = fs or self.raw_output_fs + pref = self.raw_output_prefix + s_pref = pref.split(fs.sep)[:-1] + if alt: + s_pref[2] = alt + if stem: + s_pref.append(stem) + else: + s_pref.append(self.get_random_string()) + p = fs.sep.join(s_pref) + return p + def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index eb01cdd039..b372c16d6a 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -186,18 +186,23 @@ def extension(cls) -> str: return "" @classmethod - def new_remote(cls) -> FlyteDirectory: + def new_remote(cls, stem: typing.Optional[str] = None, alt: typing.Optional[str] = None) -> FlyteDirectory: """ Create a new FlyteDirectory object using the currently configured default remote in the context (i.e. the raw_output_prefix configured in the current FileAccessProvider object in the context). This is used if you explicitly have a folder somewhere that you want to create files under. If you want to write a whole folder, you can let your task return a FlyteDirectory object, and let flytekit handle the uploading. + + :param stem: A stem to append to the path as the final prefix "directory". + :param alt: An alternate first member of the prefix to use instead of the default. + :return FlyteDirectory: A new FlyteDirectory object that points to a remote location. """ ctx = FlyteContextManager.current_context() - r = ctx.file_access.get_random_string() - d = ctx.file_access.join(ctx.file_access.raw_output_prefix, r) - return FlyteDirectory(path=d) + if stem and Path(stem).suffix: + raise ValueError("Stem should not have a file extension.") + remote_path = ctx.file_access.generate_new_custom_path(alt=alt, stem=stem) + return cls(path=remote_path) def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index e703f71ccd..087cad6b5e 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -179,13 +179,15 @@ def extension(cls) -> str: return "" @classmethod - def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: + def new_remote_file(cls, name: typing.Optional[str] = None, alt: typing.Optional[str] = None) -> FlyteFile: """ Create a new FlyteFile object with a remote path. + + :param name: If you want to specify a different name for the file, you can specify it here. + :param alt: If you want to specify a different prefix head than the default one, you can specify it here. """ ctx = FlyteContextManager.current_context() - r = name or ctx.file_access.get_random_string() - remote_path = ctx.file_access.join(ctx.file_access.raw_output_prefix, r) + remote_path = ctx.file_access.generate_new_custom_path(alt=alt, stem=name) return cls(path=remote_path) @classmethod diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 159214fe43..5063e484d2 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -136,6 +136,19 @@ def test_write_known_location(): assert f.read() == arbitrary_text.encode("utf-8") +def test_generate_new_custom_path(): + """ + Test that a new path given alternate bucket and name is generated correctly + """ + random_dir = tempfile.mkdtemp() + fs = FileAccessProvider( + local_sandbox_dir=random_dir, + raw_output_prefix="s3://my-default-bucket/my-default-prefix/" + ) + np = fs.generate_new_custom_path(alt="foo-bucket", stem="bar.txt") + assert np == "s3://foo-bucket/my-default-prefix/bar.txt" + + def test_initialise_azure_file_provider_with_account_key(): with mock.patch.dict( os.environ, diff --git a/tests/flytekit/unit/types/directory/test_types.py b/tests/flytekit/unit/types/directory/test_types.py index 199b788733..1b9cf4be97 100644 --- a/tests/flytekit/unit/types/directory/test_types.py +++ b/tests/flytekit/unit/types/directory/test_types.py @@ -22,6 +22,10 @@ def test_new_remote_dir(): fd = FlyteDirectory.new_remote() assert FlyteContext.current_context().file_access.raw_output_prefix in fd.path +def test_new_remote_dir_alt(): + ff = FlyteDirectory.new_remote(alt="my-alt-bucket", stem="my-stem") + assert "my-alt-bucket" in ff.path + assert "my-stem" in ff.path @mock.patch("flytekit.types.directory.types.os.name", "nt") def test_sep_nt(): diff --git a/tests/flytekit/unit/types/file/test_types.py b/tests/flytekit/unit/types/file/test_types.py new file mode 100644 index 0000000000..7cc6e42fea --- /dev/null +++ b/tests/flytekit/unit/types/file/test_types.py @@ -0,0 +1,7 @@ +from flytekit.types.file import FlyteFile +from flytekit import FlyteContextManager + +def test_new_remote_alt(): + ff = FlyteFile.new_remote_file(alt="my-alt-prefix", name="my-file.txt") + assert "my-alt-prefix" in ff.path + assert "my-file.txt" in ff.path