From 03c31856c6e668baccc06fa89442a74fa7997126 Mon Sep 17 00:00:00 2001 From: Suha Sabi Hussain Date: Fri, 26 Jan 2024 14:46:21 -0500 Subject: [PATCH] Refine PyTorch and Polyglot modules (#89) * Enable verbosity modifications in PyTorch module * Add new polyglot example * Add creation test * Refactor polyglot module for clarity * Fix typos and clean up comments * Refine support for TorchScript in the PyTorch module * Lint * Fix comment and file name * Remove unnecessary lines * Clean up * Lint and tidy files * Delete extraneous information * Bump version number for release * Clarify what polyglot files are * More lint * More lint --- README.md | 6 +- example/README.md | 3 +- example/fault_injection.py | 1 + example/hook_functions.py | 2 +- example/identify_pytorch_file.py | 16 ++++ fickling/__init__.py | 2 +- fickling/analysis.py | 18 ++-- fickling/cli.py | 2 +- fickling/fickle.py | 8 +- fickling/polyglot.py | 151 +++++++++++++++++++------------ fickling/pytorch.py | 29 +++--- test/test_polyglot.py | 23 +++++ test/test_pytorch.py | 16 +++- 13 files changed, 189 insertions(+), 88 deletions(-) create mode 100644 example/identify_pytorch_file.py diff --git a/README.md b/README.md index 8cc46bb..f648025 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,9 @@ Module( ### PyTorch polyglots -We currently support inspecting, identifying, and creating file polyglots between the +PyTorch contains multiple file formats with which one can make polyglot files, which +are files that can be validly interpreted as more than one file format. +Fickling supports identifying, inspecting, and creating polyglots with the following PyTorch file formats: * **PyTorch v0.1.1**: Tar file with sys_info, pickle, storages, and tensors @@ -172,7 +174,7 @@ following PyTorch file formats: * **TorchScript v1.3**: ZIP file with data.pkl and constants.pkl (2 pickle files) * **TorchScript v1.4**: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder) * **PyTorch v1.3**: ZIP file containing data.pkl (1 pickle file) -* **PyTorch model archive format**: ZIP file that includes Python code files and pickle files +* **PyTorch model archive format[ZIP]**: ZIP file that includes Python code files and pickle files ```python >> import torch diff --git a/example/README.md b/example/README.md index 821e3ca..aee4390 100644 --- a/example/README.md +++ b/example/README.md @@ -4,6 +4,7 @@ * [context_manager.py](https://github.com/trailofbits/fickling/blob/master/example/context_manager.py): Halt the deserialization of a malicious pickle file with the fickling context manager * [fault_injection.py](https://github.com/trailofbits/fickling/blob/master/example/fault_injection.py): Perform a fault injection on a PyTorch model and then analyze the result with `check_safety` * [inject_mobilenet.py](https://github.com/trailofbits/fickling/blob/master/example/inject_mobilenet.py): Override the `eval` method of a ML model using fickling and apply `fickling.is_likely_safe` to the model file -* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using fickling’s PyTorch module +* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using the PyTorch module * [numpy_poc.py](https://github.com/trailofbits/fickling/blob/master/example/numpy_poc.py): Analyze a malicious payload passed to `numpy.load()` * [trace_binary.py](https://github.com/trailofbits/fickling/blob/master/example/trace_binary.py): Decompile a payload using the tracing module +* [identify_pytorch_file.py](https://github.com/trailofbits/fickling/blob/master/example/identify_pytorch_file.py): Identify 2 PyTorch files that are different formats \ No newline at end of file diff --git a/example/fault_injection.py b/example/fault_injection.py index 9c3f871..feb9de9 100644 --- a/example/fault_injection.py +++ b/example/fault_injection.py @@ -4,6 +4,7 @@ Note you may need to run `pip install pytorchfi` """ + import pickle import torch diff --git a/example/hook_functions.py b/example/hook_functions.py index b91862d..a6ec808 100644 --- a/example/hook_functions.py +++ b/example/hook_functions.py @@ -7,7 +7,7 @@ # Set up global fickling hook fickling.always_check_safety() -# Eauivalent to fickling.hook.run_hook() +# Equivalent to fickling.hook.run_hook() # Fickling can check a pickle file for safety prior to running it test_list = [1, 2, 3] diff --git a/example/identify_pytorch_file.py b/example/identify_pytorch_file.py new file mode 100644 index 0000000..bba42b8 --- /dev/null +++ b/example/identify_pytorch_file.py @@ -0,0 +1,16 @@ +import torch +import torchvision.models as models + +import fickling.polyglot as polyglot + +model = models.mobilenet_v2() +torch.save(model, "mobilenet.pth") +torch.save(model, "legacy_mobilenet.pth", _use_new_zipfile_serialization=False) + +print("Identifying PyTorch v1.3 file:") +potential_formats = polyglot.identify_pytorch_file_format("mobilenet.pth", print_results=True) + +print("Identifying PyTorch v0.1.10 file:") +potential_formats_legacy = polyglot.identify_pytorch_file_format( + "legacy_mobilenet.pth", print_results=True +) diff --git a/fickling/__init__.py b/fickling/__init__.py index 9d14d28..ac4dc0e 100644 --- a/fickling/__init__.py +++ b/fickling/__init__.py @@ -7,4 +7,4 @@ # The above lines enables `fickling.load()` and `with fickling.check_safety()` # The comments are necessary to comply with linters -__version__ = "0.0.8" +__version__ = "0.1.0" diff --git a/fickling/analysis.py b/fickling/analysis.py index 8c9ddfc..1fa41a1 100644 --- a/fickling/analysis.py +++ b/fickling/analysis.py @@ -107,9 +107,9 @@ def __init__( self.severity: Severity = severity self.message: Optional[str] = message self.analysis_name: str = analysis_name - self.trigger: Optional[ - str - ] = trigger # Field to store the trigger code fragment or artifact + self.trigger: Optional[str] = ( + trigger # Field to store the trigger code fragment or artifact + ) def __lt__(self, other): return isinstance(other, AnalysisResult) and ( @@ -296,11 +296,13 @@ def to_dict(self, verbosity: Severity = Severity.SUSPICIOUS): analysis_message = self.to_string(verbosity) severity_data = { "severity": self.severity.name, - "analysis": analysis_message - if analysis_message.strip() - else "Warning: Fickling failed to detect any overtly unsafe code, but the pickle file" - "may still be unsafe." - "Do not unpickle this file if it is from an untrusted source!\n\n", + "analysis": ( + analysis_message + if analysis_message.strip() + else "Warning: Fickling failed to detect any overtly unsafe code," + "but the pickle file may still be unsafe." + "Do not unpickle this file if it is from an untrusted source!\n\n" + ), "detailed_results": self.detailed_results(), } return severity_data diff --git a/fickling/cli.py b/fickling/cli.py index 0391876..0aed818 100644 --- a/fickling/cli.py +++ b/fickling/cli.py @@ -35,7 +35,7 @@ def main(argv: Optional[List[str]] = None) -> int: "-i", type=str, default=None, - help="inject the specified Python code to be run at the end of depickling, " + help="inject the specified Python code to be run at the end of unpickling, " "and output the resulting pickle data", ) parser.add_argument( diff --git a/fickling/fickle.py b/fickling/fickle.py index ee51ed4..ac7c7a0 100644 --- a/fickling/fickle.py +++ b/fickling/fickle.py @@ -508,7 +508,7 @@ def append_python( def insert_magic_int(self, magic: int, index: int = -1): """Insert and pop a specific integer value. This is used for persistent injections to locate the injection payload in the pickled file. The value - is artificially added by using an dummy INT + POP combination that doesn't + is artificially added by using a dummy INT + POP combination that doesn't affect the stack when executed :param magic: magic integer value to add @@ -760,13 +760,11 @@ def __init__(self, initial_value: Iterable[T] = ()): @overload @abstractmethod - def __getitem__(self, i: int) -> T: - ... + def __getitem__(self, i: int) -> T: ... @overload @abstractmethod - def __getitem__(self, s: slice) -> GenericSequence: - ... + def __getitem__(self, s: slice) -> GenericSequence: ... def __getitem__(self, i: int) -> T: return self._stack[i] diff --git a/fickling/polyglot.py b/fickling/polyglot.py index 8cc76f1..de4b9b3 100644 --- a/fickling/polyglot.py +++ b/fickling/polyglot.py @@ -19,10 +19,7 @@ • TorchScript v1.3: ZIP file with data.pkl and constants.pkl (2 pickle files) • TorchScript v1.4: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder) • PyTorch v1.3: ZIP file containing data.pkl (1 pickle file) -• PyTorch model archive format: ZIP file that includes Python code files and pickle files - -Officially, PyTorch v0.1.1 and TorchScript < v1.4 are deprecated. -However, they are still supported by some legacy parsers +• PyTorch model archive format[ZIP]: ZIP file that includes Python code files and pickle files This description draws from this PyTorch GitHub issue: https://github.com/pytorch/pytorch/issues/31877. If any inaccuracies in that description are found, that should be reflected in this code. @@ -71,14 +68,14 @@ def check_pickle(file): def find_file_properties(file_path, print_properties=False): - """For a more granular analysis, we separate property discover and format identification""" + """For a more granular analysis, we separate property discovery and format identification""" properties = {} with open(file_path, "rb") as file: # PyTorch's torch.load() enforces a specific magic number at offset 0 for ZIP is_torch_zip = _is_zipfile(file) properties["is_torch_zip"] = is_torch_zip - # This tarfile check has many false positivies. It is not a determinant of PyTorch v0.1.1. + # This tarfile check has many false positives. It is not a determinant of PyTorch v0.1.1. if sys.version_info >= (3, 9): is_tar = tarfile.is_tarfile(file) else: @@ -175,10 +172,12 @@ def check_for_corruption(properties): return corrupted, reason -def identify_pytorch_file_format(file, print_properties=False): +def identify_pytorch_file_format(file, print_properties=False, print_results=False): """ We are intentionally matching the semantics of the PyTorch reference parsers. To be polyglot-aware, we show the file formats ranked by likelihood. + Our parsing depth is at the file structure level; + However, it can be at the full parsing level if necessary. """ properties = find_file_properties(file, print_properties) formats = [] @@ -215,16 +214,21 @@ def identify_pytorch_file_format(file, print_properties=False): print(reason) if len(formats) != 0: primary = formats[0] - print("Your file is most likely of this format: ", primary, "\n") + if print_results: + print("Your file is most likely of this format: ", primary, "\n") secondary = formats[1:] if len(secondary) != 0: - print("It is also possible that your file can be validly interpreted as: ", secondary) + if print_results: + print( + "It is also possible that your file can be validly interpreted as: ", secondary + ) else: - print( - """Your file may not be a PyTorch file. - No valid file formats were detected. - If this is a mistake, raise an issue on our GitHub.""" - ) + if print_results: + print( + """Your file may not be a PyTorch file. + No valid file formats were detected. + If this is a mistake, raise an issue on our GitHub.""" + ) return formats @@ -236,11 +240,67 @@ def append_file(source_filename, destination_filename): return destination_filename -def make_zip_pickle_polyglot(zip_file, pickle_file, copy=False): +def create_zip_pickle_polyglot(zip_file, pickle_file): append_file(zip_file, pickle_file) -def create_polyglot(first_file, second_file): +def create_mar_legacy_pickle_polyglot( + files, print_results=False, polyglot_file_name="polyglot.mar.pt" +): + files.sort(key=lambda x: x[1] != "PyTorch model archive format") + if print_results: + print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot") + polyglot_file = append_file(*[file[0] for file in files]) + shutil.copy(polyglot_file, polyglot_file_name) + polyglot_found = True + return polyglot_found + + +def create_standard_torchscript_polyglot( + files, print_results=False, polyglot_file_name="polyglot.pt" +): + if print_results: + print("Making a PyTorch v1.3/TorchScript v1.4 polyglot") + print("Warning: For some parsers, this may generate polymocks instead of polyglots.") + standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0] + torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0] + if polyglot_file_name is None: + polyglot_file_name = "polyglot.pt" + shutil.copy(standard_pytorch_file, polyglot_file_name) + + with zipfile.ZipFile(torchscript_file, "r") as zip_b: + constants_pkl_path = check_and_find_in_zip( + zip_b, "constants.pkl", check_extension=False, return_path=True + ) + version_path = check_and_find_in_zip(zip_b, "version", return_path=True) + if constants_pkl_path and version_path: + zip_b.extract(constants_pkl_path, "temp") + zip_b.extract(version_path, "temp") + + with zipfile.ZipFile(polyglot_file_name, "a") as zip_out: + zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl") + zip_out.write(f"temp/{version_path}", "version") + + shutil.rmtree("temp") + polyglot_found = True + return polyglot_found + + +def create_mar_legacy_tar_polyglot( + files, print_results=False, polyglot_file_name="polyglot.mar.tar" +): + if print_results: + print("Making a PyTorch v0.1.1/PyTorch MAR polyglot") + mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0] + tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0] + polyglot_file = append_file(mar_file, tar_file) + shutil.copy(polyglot_file, polyglot_file_name) + polyglot_found = True + return polyglot_found + + +def create_polyglot(first_file, second_file, polyglot_file_name=None, print_results=True): + polyglot_found = False temp_first_file = "temp_" + os.path.basename(first_file) temp_second_file = "temp_" + os.path.basename(second_file) shutil.copy(first_file, temp_first_file) @@ -250,49 +310,28 @@ def create_polyglot(first_file, second_file): (temp_second_file, identify_pytorch_file_format(temp_second_file)[0]), ] formats = set(map(lambda x: x[1], files)) # noqa - polyglot_found = False if {"PyTorch model archive format", "PyTorch v0.1.10"}.issubset(formats): - files.sort(key=lambda x: x[1] != "PyTorch model archive format") - print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot") - polyglot_found = True - polyglot_file = append_file(*[file[0] for file in files]) - shutil.copy(polyglot_file, "polyglot.mar.pt") - print("The polyglot is contained in polyglot.mar.pt") + if polyglot_file_name is None: + polyglot_file_name = "polyglot.mar.pt" + polyglot_found = create_mar_legacy_pickle_polyglot(files, print_results, polyglot_file_name) if {"PyTorch v1.3", "TorchScript v1.4"}.issubset(formats): - print("Making a PyTorch v1.3/TorchScript v1.4 polyglot") - print("Warning: For some parsers, this may generate polymocks instead of polyglots.") - polyglot_found = True - standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0] - torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0] - shutil.copy(standard_pytorch_file, "polyglot.pt") - - with zipfile.ZipFile(torchscript_file, "r") as zip_b: - constants_pkl_path = check_and_find_in_zip( - zip_b, "constants.pkl", check_extension=False, return_path=True - ) - version_path = check_and_find_in_zip(zip_b, "version", return_path=True) - if constants_pkl_path and version_path: - zip_b.extract(constants_pkl_path, "temp") - zip_b.extract(version_path, "temp") - - with zipfile.ZipFile("polyglot.pt", "a") as zip_out: - zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl") - zip_out.write(f"temp/{version_path}", "version") - - shutil.rmtree("temp") - if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats): - print("Making a PyTorch v0.1.1/PyTorch MAR polyglot") - polyglot_found = True - mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0] - tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0] - polyglot_file = append_file(mar_file, tar_file) - shutil.copy(polyglot_file, "polyglot.mar.tar") - print("The polyglot is contained in polyglot.mar.tar") - if polyglot_found is False: - print( - """Fickling was not able to create any polglots. - If you think this is a mistake, raise an issue on our GitHub.""" + if polyglot_file_name is None: + polyglot_file_name = "polyglot.pt" + polyglot_found = create_standard_torchscript_polyglot( + files, print_results, polyglot_file_name ) + if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats): + if polyglot_file_name is None: + polyglot_file_name = "polyglot.mar.tar" + polyglot_found = create_mar_legacy_tar_polyglot(files, print_results, polyglot_file_name) + if print_results: + if polyglot_found is False: + print( + """Fickling was not able to create any polyglots. + If you think this is a mistake, raise an issue on our GitHub.""" + ) + else: + print(f"The polyglot is contained in {polyglot_file_name}") os.remove(temp_first_file) os.remove(temp_second_file) return polyglot_found diff --git a/fickling/pytorch.py b/fickling/pytorch.py index 976a5bc..d8f77f2 100644 --- a/fickling/pytorch.py +++ b/fickling/pytorch.py @@ -34,12 +34,11 @@ def __init__(self, path: Path, force: bool = False): def validate_file_format(self): self._formats = fickling.polyglot.identify_pytorch_file_format(self.path) """ - One option was to raise an error if PyTorch v1.3 was not found - or if any of the TorchScript versions were found. - However, that would prevent polyglots from being loaded. - Therefore, the 'force' argument was created to enable users to do that if needed. - Another option was to warn only if "PyTorch v1.3" was not the most likely format. - Instead, the file formats are directly specified for clarity and independence. + PyTorch v1.3 and TorchScript v1.4 are explicitly supported by PyTorchModelWrapper. + This class may work on other file formats depending on its construction. + To enable users to check that and load polyglots, the force argument exists. + There is a warning for TorchScript v1.4 because of the scripting/tracing/mixing edge cases. + For example, an injection may work on torch.load() but not torch.jit.load(). """ if len(self._formats) == 0: if self.force is True: @@ -57,12 +56,7 @@ def validate_file_format(self): If it is a PyTorch file, raise an issue on GitHub. """ ) - if ("PyTorch v1.3" not in self._formats) or { - "TorchScript v1.4", - "TorchScript v1.3", - "TorchScript v1.1", - "TorchScript v1.0", - }.intersection(self._formats): + if ("PyTorch v1.3" not in self._formats) and ("TorchScript v1.4" not in self._formats): if "PyTorch v0.1.10" in self._formats: if self.force is True: warnings.warn( @@ -92,6 +86,11 @@ def validate_file_format(self): """A fickling wrapper and injection method does not exist for that format. Please raise an issue on our GitHub or use the argument `force=True`.""" ) + if self._formats[0] == "TorchScript v1.4": + warnings.warn( + """Support for TorchScript v1.4 files is experimental.""", + UserWarning, + ) return self._formats @property @@ -118,6 +117,12 @@ def inject_payload( self, payload: str, output_path: Path, injection: str = "all", overwrite: bool = False ) -> None: self.output_path = output_path + if self.formats[0] == "TorchScript v1.4": + warnings.warn( + """Support for TorchScript v1.4 files is experimental. + Injections may not be effective depending on the model and the target parser.""", + UserWarning, + ) if injection == "insertion": # This does NOT bypass the weights based unpickler pickled = self.pickled diff --git a/test/test_polyglot.py b/test/test_polyglot.py index f090485..f0658fb 100644 --- a/test/test_polyglot.py +++ b/test/test_polyglot.py @@ -52,6 +52,10 @@ def setUp(self): self.filename_v1_3 = "model_v1_3.pth" torch.save(model, self.filename_v1_3) + # PyTorch v1.3 Dup (for testing) + self.filename_v1_3_dup = "model_v1_3_dup.pth" + torch.save(model, self.filename_v1_3_dup) + # PyTorch v0.1.10 (Stacked pickle files) self.filename_legacy_pickle = "model_legacy_pickle.pth" torch.save(model, self.filename_legacy_pickle, _use_new_zipfile_serialization=False) @@ -61,6 +65,10 @@ def setUp(self): self.filename_torchscript = "model_torchscript.pt" torch.jit.save(m, self.filename_torchscript) + # TorchScript v1.4 + self.filename_torchscript_dup = "model_torchscript_dup.pt" + torch.jit.save(m, self.filename_torchscript_dup) + # PyTorch v0.1.1 self.filename_legacy_tar = "model_legacy_tar.pth" create_pytorch_legacy_tar(self.filename_legacy_tar) @@ -70,6 +78,8 @@ def setUp(self): create_random_zip(self.zip_filename) prepend_random_string(self.zip_filename) + self.standard_torchscript_polyglot_name = "test_polyglot.pt" + def tearDown(self): for filename in [ self.filename_v1_3, @@ -77,6 +87,9 @@ def tearDown(self): self.filename_torchscript, self.filename_legacy_tar, self.zip_filename, + self.filename_torchscript_dup, + self.filename_v1_3_dup, + self.standard_torchscript_polyglot_name, ]: if os.path.exists(filename): os.remove(filename) @@ -166,3 +179,13 @@ def test_zip_properties(self): "has_attribute_pkl": False, } self.assertEqual(properties, proper_result) + + def test_create_standard_torchscript_polyglot(self): + polyglot.create_polyglot( + self.filename_v1_3_dup, + self.filename_torchscript_dup, + self.standard_torchscript_polyglot_name, + print_results=False, + ) + formats = polyglot.identify_pytorch_file_format(self.standard_torchscript_polyglot_name) + self.assertTrue({"PyTorch v1.3", "TorchScript v1.4"}.issubset(formats)) diff --git a/test/test_pytorch.py b/test/test_pytorch.py index e0eeae4..fcb8746 100644 --- a/test/test_pytorch.py +++ b/test/test_pytorch.py @@ -14,9 +14,12 @@ def setUp(self): self.filename_v1_3 = "test_model.pth" torch.save(model, self.filename_v1_3) self.zip_filename = "test_random_data.zip" + m = torch.jit.script(model) + self.torchscript_filename = "test_model_torchscript.pth" + torch.jit.save(m, self.torchscript_filename) def tearDown(self): - for filename in [self.filename_v1_3, self.zip_filename]: + for filename in [self.filename_v1_3, self.zip_filename, self.torchscript_filename]: if os.path.exists(filename): os.remove(filename) @@ -26,11 +29,22 @@ def test_wrapper(self): except Exception as e: # noqa self.fail(f"PyTorchModelWrapper was not able to load a PyTorch v1.3 file: {e}") + def test_torchscript_wrapper(self): + try: + PyTorchModelWrapper(self.torchscript_filename) + except Exception as e: # noqa + self.fail(f"PyTorchModelWrapper was not able to load a TorchScript v1.4 file: {e}") + def test_pickled(self): result = PyTorchModelWrapper(self.filename_v1_3) pickled_portion = result.pickled self.assertIsInstance(pickled_portion, Pickled) + def test_torchscript_pickled(self): + result = PyTorchModelWrapper(self.torchscript_filename) + pickled_portion = result.pickled + self.assertIsInstance(pickled_portion, Pickled) + def test_injection_insertion(self): try: result = PyTorchModelWrapper(self.filename_v1_3)