diff --git a/Dockerfile.dev b/Dockerfile.dev
index d7d9b46308..2d01b09c5c 100644
--- a/Dockerfile.dev
+++ b/Dockerfile.dev
@@ -15,7 +15,7 @@ WORKDIR /root
 
 ARG VERSION
 
-RUN apt-get update && apt-get install build-essential vim -y
+RUN apt-get update && apt-get install build-essential vim libmagic1 -y
 
 COPY . /flytekit
 
diff --git a/dev-requirements.in b/dev-requirements.in
index b80953e1b9..407799b5b2 100644
--- a/dev-requirements.in
+++ b/dev-requirements.in
@@ -27,6 +27,12 @@ torch<=1.12.1; python_version<'3.11'
 # pytorch 2 supports python 3.11
 torch<=2.0.0; python_version>='3.11' or platform_system!='Windows'
 
+# TODO: Currently, the python-magic library causes build errors on Windows due to its dependency on DLLs for libmagic.
+# We have temporarily disabled this feature on Windows and are using python-magic for Mac OS and Linux instead.
+# For more details, see the related GitHub issue.
+# Once a solution is found, this should be updated to support Windows as well.
+python-magic; (platform_system=='Darwin' or platform_system=='Linux')
+
 pillow
 scikit-learn
 types-protobuf
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index f51c48cdf5..b189190494 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import mimetypes
 import os
 import pathlib
 import typing
@@ -324,6 +325,77 @@ def assert_type(
     def get_literal_type(self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> LiteralType:
         return LiteralType(blob=self._blob_type(format=FlyteFilePathTransformer.get_format(t)))
 
+    def get_mime_type_from_extension(self, extension: str) -> str:
+        """
+        Return the MIME type that a file of the given format is expected to have.
+
+        :param extension: The file format / extension without the leading dot, e.g. ``txt``
+        :raises KeyError: If no MIME type is known for the extension
+        """
+        # The keys of mimetypes.types_map carry a leading dot (".txt"); FlyteFile formats do not.
+        extension_to_mime_type = {ext.lstrip("."): mimetype for ext, mimetype in mimetypes.types_map.items()}
+        # Apply the flytekit-specific formats last so that mimetypes entries can never override them.
+        # NOTE(review): hdf5 and onnx are binary formats - confirm these are the MIME types libmagic reports for them.
+        extension_to_mime_type.update(
+            {
+                "hdf5": "text/plain",
+                "joblib": "application/octet-stream",
+                "python_pickle": "application/octet-stream",
+                "ipynb": "application/json",
+                "onnx": "application/json",
+                "tfrecord": "application/octet-stream",
+            }
+        )
+
+        return extension_to_mime_type[extension]
+
+    def validate_file_type(
+        self, python_type: typing.Type[FlyteFile], source_path: typing.Union[str, os.PathLike]
+    ) -> None:
+        """
+        This method validates the type of the file at source_path against the expected python_type.
+        It uses the magic library to determine the real type of the file. Validation is best effort: it is skipped,
+        without raising an error, if python_type declares no format, if the magic library is not installed (a debug
+        message is logged), if source_path is remote or does not exist, or if no MIME type is known for the format.
+
+        :param python_type: The expected type of the file
+        :param source_path: The path to the file to validate
+        :raises ValueError: If the real type of the file is not the same as the expected python_type
+        """
+        expected_format = FlyteFilePathTransformer.get_format(python_type)
+        if not expected_format:
+            return
+
+        try:
+            # isolate the exception to the libmagic import
+            import magic
+
+        except ImportError as e:
+            logger.debug(f"Libmagic is not installed. Error message: {e}")
+            return
+
+        ctx = FlyteContext.current_context()
+        if ctx.file_access.is_remote(source_path):
+            # Skip validation for remote files. One of the use cases for FlyteFile is to point to remote files,
+            # you might have access to a remote file (e.g., in s3) that you want to pass to a Flyte workflow.
+            # Therefore, we should only validate FlyteFiles for which their path is considered local.
+            return
+
+        if not os.path.exists(source_path):
+            # Nothing to inspect (see the contract in the docstring): a missing file is not a type mismatch.
+            return
+
+        try:
+            expected_type = self.get_mime_type_from_extension(expected_format)
+        except KeyError:
+            # A format without a known MIME type (e.g. a user-defined one) cannot be validated; that is not an error.
+            logger.debug(f"No MIME type is known for format {expected_format}, skipping file type validation.")
+            return
+
+        real_type = magic.from_file(source_path, mime=True)
+        if real_type != expected_type:
+            raise ValueError(f"Incorrect file type, expected {expected_type}, got {real_type}")
+
     def to_literal(
         self,
         ctx: FlyteContext,
@@ -348,6 +420,7 @@ def to_literal(
 
         if isinstance(python_val, FlyteFile):
             source_path = python_val.path
+            self.validate_file_type(python_type, source_path)
 
             # If the object has a remote source, then we just convert it back.  This means that if someone is just
             # going back and forth between a FlyteFile Python value and a Blob Flyte IDL value, we don't do anything.
@@ -373,6 +446,7 @@ def to_literal(
         elif isinstance(python_val, pathlib.Path) or isinstance(python_val, str):
             source_path = str(python_val)
             if issubclass(python_type, FlyteFile):
+                self.validate_file_type(python_type, source_path)
                 if ctx.file_access.is_remote(source_path):
                     should_upload = False
                 else:
diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py
index 147e1e8bd8..9922436205 100644
--- a/tests/flytekit/unit/core/test_flyte_file.py
+++ b/tests/flytekit/unit/core/test_flyte_file.py
@@ -2,7 +2,7 @@
 import pathlib
 import tempfile
 import typing
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 from typing_extensions import Annotated
@@ -19,7 +19,7 @@
 from flytekit.core.workflow import workflow
 from flytekit.models.core.types import BlobType
 from flytekit.models.literals import LiteralMap
-from flytekit.types.file.file import FlyteFile
+from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
 
 
 # Fixture that ensures a dummy local file
@@ -34,6 +34,25 @@ def local_dummy_file():
         os.remove(path)
 
 
+@pytest.fixture
+def local_dummy_txt_file():
+    fd, path = tempfile.mkstemp(suffix=".txt")
+    try:
+        with os.fdopen(fd, "w") as tmp:
+            tmp.write("Hello World")
+        yield path
+    finally:
+        os.remove(path)
+
+
+def can_import(module_name) -> bool:
+    try:
+        __import__(module_name)
+        return True
+    except ImportError:
+        return False
+
+
 def test_file_type_in_workflow_with_bad_format():
     @task
     def t1() -> FlyteFile[typing.TypeVar("txt")]:
@@ -52,6 +71,128 @@ def my_wf() -> FlyteFile[typing.TypeVar("txt")]:
         assert fh.read() == "Hello World\n"
 
 
+def test_matching_file_types_in_workflow(local_dummy_txt_file):
+    # TXT
+    @task
+    def t1(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("txt")]:
+        return path
+
+    @workflow
+    def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("txt")]:
+        f = t1(path=path)
+        return f
+
+    res = my_wf(path=local_dummy_txt_file)
+    with open(res, "r") as fh:
+        assert fh.read() == "Hello World"
+
+
+def test_file_types_with_naked_flytefile_in_workflow(local_dummy_txt_file):
+    @task
+    def t1(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile:
+        return path
+
+    @workflow
+    def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile:
+        f = t1(path=path)
+        return f
+
+    res = my_wf(path=local_dummy_txt_file)
+    with open(res, "r") as fh:
+        assert fh.read() == "Hello World"
+
+
+@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed")
+def test_mismatching_file_types(local_dummy_txt_file):
+    @task
+    def t1(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("jpeg")]:
+        return path
+
+    @workflow
+    def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile[typing.TypeVar("jpeg")]:
+        f = t1(path=path)
+        return f
+
+    with pytest.raises(TypeError) as excinfo:
+        my_wf(path=local_dummy_txt_file)
+    assert "Incorrect file type, expected image/jpeg, got text/plain" in str(excinfo.value)
+
+
+def test_get_mime_type_from_extension_success():
+    transformer = TypeEngine.get_transformer(FlyteFile)
+    assert transformer.get_mime_type_from_extension("html") == "text/html"
+    assert transformer.get_mime_type_from_extension("jpeg") == "image/jpeg"
+    assert transformer.get_mime_type_from_extension("png") == "image/png"
+    assert transformer.get_mime_type_from_extension("hdf5") == "text/plain"
+    assert transformer.get_mime_type_from_extension("joblib") == "application/octet-stream"
+    assert transformer.get_mime_type_from_extension("pdf") == "application/pdf"
+    assert transformer.get_mime_type_from_extension("python_pickle") == "application/octet-stream"
+    assert transformer.get_mime_type_from_extension("ipynb") == "application/json"
+    assert transformer.get_mime_type_from_extension("svg") == "image/svg+xml"
+    assert transformer.get_mime_type_from_extension("csv") == "text/csv"
+    assert transformer.get_mime_type_from_extension("onnx") == "application/json"
+    assert transformer.get_mime_type_from_extension("tfrecord") == "application/octet-stream"
+    assert transformer.get_mime_type_from_extension("txt") == "text/plain"
+
+
+def test_get_mime_type_from_extension_failure():
+    transformer = TypeEngine.get_transformer(FlyteFile)
+    with pytest.raises(KeyError):
+        transformer.get_mime_type_from_extension("unknown_extension")
+
+
+@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed")
+def test_validate_file_type_incorrect(local_dummy_file):
+    transformer = TypeEngine.get_transformer(FlyteFile)
+    source_file_mime_type = "image/png"
+    user_defined_format = "jpeg"
+
+    with patch.object(FlyteFilePathTransformer, "get_format", return_value=user_defined_format):
+        with patch("magic.from_file", return_value=source_file_mime_type):
+            with pytest.raises(
+                ValueError, match=f"Incorrect file type, expected image/jpeg, got {source_file_mime_type}"
+            ):
+                # validation is skipped for missing files, so point it at a file that really exists
+                transformer.validate_file_type(user_defined_format, local_dummy_file)
+
+
+def test_validate_file_type_skips_missing_file(tmp_path):
+    transformer = TypeEngine.get_transformer(FlyteFile)
+    # must not raise: there is no file to inspect
+    transformer.validate_file_type(FlyteFile[typing.TypeVar("jpeg")], str(tmp_path / "missing.jpeg"))
+
+
+def test_validate_file_type_skips_unknown_format(local_dummy_file):
+    transformer = TypeEngine.get_transformer(FlyteFile)
+    # must not raise: a format without a known MIME type cannot be validated
+    transformer.validate_file_type(FlyteFile[typing.TypeVar("unknown_extension")], local_dummy_file)
+
+
+@pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed")
+def test_flyte_file_type_annotated_hashmethod(local_dummy_file):
+    def calc_hash(ff: FlyteFile) -> str:
+        return str(ff.path)
+
+    HashedFlyteFile = Annotated[FlyteFile["jpeg"], HashMethod(calc_hash)]
+
+    @task
+    def t1(path: str) -> HashedFlyteFile:
+        return HashedFlyteFile(path)
+
+    @task
+    def t2(ff: HashedFlyteFile) -> None:
+        print(ff.path)
+
+    @workflow
+    def wf(path: str) -> None:
+        ff = t1(path=path)
+        t2(ff=ff)
+
+    with pytest.raises(TypeError) as excinfo:
+        wf(path=local_dummy_file)
+    assert "Incorrect file type, expected image/jpeg, got text/plain" in str(excinfo.value)
+
+
 def test_file_handling_remote_default_wf_input():
     SAMPLE_DATA = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv"
 
diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py
index 4e32070f9f..f1e04f6718 100644
--- a/tests/flytekit/unit/core/test_type_hints.py
+++ b/tests/flytekit/unit/core/test_type_hints.py
@@ -41,7 +41,7 @@
 from flytekit.models.types import LiteralType, SimpleType
 from flytekit.tools.translator import get_serializable
 from flytekit.types.directory import FlyteDirectory, TensorboardLogs
-from flytekit.types.file import FlyteFile, PNGImageFile
+from flytekit.types.file import FlyteFile
 from flytekit.types.schema import FlyteSchema, SchemaOpenMode
 from flytekit.types.structured.structured_dataset import StructuredDataset
 
@@ -390,7 +390,7 @@ def test_flyte_file_in_dataclass():
     @dataclass
     class InnerFileStruct(DataClassJsonMixin):
         a: FlyteFile
-        b: PNGImageFile
+        b: FlyteFile
 
     @dataclass
     class FileStruct(DataClassJsonMixin):
@@ -400,7 +400,7 @@ class FileStruct(DataClassJsonMixin):
     @task
     def t1(path: str) -> FileStruct:
         file = FlyteFile(path)
-        fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=PNGImageFile(path)))
+        fs = FileStruct(a=file, b=InnerFileStruct(a=file, b=FlyteFile(path)))
         return fs
 
     @dynamic
diff --git a/tests/flytekit/unit/extras/tasks/test_shell.py b/tests/flytekit/unit/extras/tasks/test_shell.py
index 1e92377202..fca19ff0f9 100644
--- a/tests/flytekit/unit/extras/tasks/test_shell.py
+++ b/tests/flytekit/unit/extras/tasks/test_shell.py
@@ -117,7 +117,7 @@ def test_input_output_substitution_files():
         name="test",
         debug=True,
         script=script,
-        inputs=kwtypes(f=CSVFile),
+        inputs=kwtypes(f=FlyteFile),
         output_locs=[
             OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.mod"),
         ],
@@ -127,11 +127,10 @@ def test_input_output_substitution_files():
     contents = "1,2,3,4\n"
 
     with tempfile.TemporaryDirectory() as tmp:
-        csv = os.path.join(tmp, "abc.csv")
-        print(csv)
-        with open(csv, "w") as f:
+        test_data = os.path.join(tmp, "abc.txt")
+        with open(test_data, "w") as f:
             f.write(contents)
-        y = t(f=csv)
+        y = t(f=test_data)
         assert y.path[-4:] == ".mod"
         assert os.path.exists(y.path)
         with open(y.path) as f:
diff --git a/tests/flytekit/unit/extras/tasks/testdata/test.csv b/tests/flytekit/unit/extras/tasks/testdata/test.csv
index e69de29bb2..c8f2749b92 100644
--- a/tests/flytekit/unit/extras/tasks/testdata/test.csv
+++ b/tests/flytekit/unit/extras/tasks/testdata/test.csv
@@ -0,0 +1,4 @@
+SN,Name,Contribution
+1,Linus Torvalds,Linux Kernel
+2,Tim Berners-Lee,World Wide Web
+3,Guido van Rossum,Python Programming