Skip to content

Commit

Permalink
Refine PyTorch and Polyglot modules (#89)
Browse files Browse the repository at this point in the history
* Enable verbosity modifications in PyTorch module

* Add new polyglot example

* Add creation test

* Refactor polyglot module for clarity

* Fix typos and clean up comments

* Refine support for TorchScript in the PyTorch module

* Lint

* Fix comment and file name

* Remove unnecessary lines

* Clean up

* Lint and tidy files

* Delete extraneous information

* Bump version number for release

* Clarify what polyglot files are

* More lint

* More lint
  • Loading branch information
suhacker1 authored Jan 26, 2024
1 parent 4503d1c commit 03c3185
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 88 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ Module(

### PyTorch polyglots

We currently support inspecting, identifying, and creating file polyglots between the
PyTorch contains multiple file formats with which one can make polyglot files, which
are files that can be validly interpreted as more than one file format.
Fickling supports identifying, inspecting, and creating polyglots with the
following PyTorch file formats:

* **PyTorch v0.1.1**: Tar file with sys_info, pickle, storages, and tensors
Expand All @@ -172,7 +174,7 @@ following PyTorch file formats:
* **TorchScript v1.3**: ZIP file with data.pkl and constants.pkl (2 pickle files)
* **TorchScript v1.4**: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
* **PyTorch v1.3**: ZIP file containing data.pkl (1 pickle file)
* **PyTorch model archive format**: ZIP file that includes Python code files and pickle files
* **PyTorch model archive format[ZIP]**: ZIP file that includes Python code files and pickle files

```python
>> import torch
Expand Down
3 changes: 2 additions & 1 deletion example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* [context_manager.py](https://github.com/trailofbits/fickling/blob/master/example/context_manager.py): Halt the deserialization of a malicious pickle file with the fickling context manager
* [fault_injection.py](https://github.com/trailofbits/fickling/blob/master/example/fault_injection.py): Perform a fault injection on a PyTorch model and then analyze the result with `check_safety`
* [inject_mobilenet.py](https://github.com/trailofbits/fickling/blob/master/example/inject_mobilenet.py): Override the `eval` method of a ML model using fickling and apply `fickling.is_likely_safe` to the model file
* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using fickling’s PyTorch module
* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using the PyTorch module
* [numpy_poc.py](https://github.com/trailofbits/fickling/blob/master/example/numpy_poc.py): Analyze a malicious payload passed to `numpy.load()`
* [trace_binary.py](https://github.com/trailofbits/fickling/blob/master/example/trace_binary.py): Decompile a payload using the tracing module
* [identify_pytorch_file.py](https://github.com/trailofbits/fickling/blob/master/example/identify_pytorch_file.py): Identify 2 PyTorch files that are different formats
1 change: 1 addition & 0 deletions example/fault_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Note you may need to run `pip install pytorchfi`
"""

import pickle

import torch
Expand Down
2 changes: 1 addition & 1 deletion example/hook_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# Set up global fickling hook
fickling.always_check_safety()
# Eauivalent to fickling.hook.run_hook()
# Equivalent to fickling.hook.run_hook()

# Fickling can check a pickle file for safety prior to running it
test_list = [1, 2, 3]
Expand Down
16 changes: 16 additions & 0 deletions example/identify_pytorch_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
import torchvision.models as models

import fickling.polyglot as polyglot

model = models.mobilenet_v2()
torch.save(model, "mobilenet.pth")
torch.save(model, "legacy_mobilenet.pth", _use_new_zipfile_serialization=False)

print("Identifying PyTorch v1.3 file:")
potential_formats = polyglot.identify_pytorch_file_format("mobilenet.pth", print_results=True)

print("Identifying PyTorch v0.1.10 file:")
potential_formats_legacy = polyglot.identify_pytorch_file_format(
"legacy_mobilenet.pth", print_results=True
)
2 changes: 1 addition & 1 deletion fickling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

# The above lines enables `fickling.load()` and `with fickling.check_safety()`
# The comments are necessary to comply with linters
__version__ = "0.0.8"
__version__ = "0.1.0"
18 changes: 10 additions & 8 deletions fickling/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def __init__(
self.severity: Severity = severity
self.message: Optional[str] = message
self.analysis_name: str = analysis_name
self.trigger: Optional[
str
] = trigger # Field to store the trigger code fragment or artifact
self.trigger: Optional[str] = (
trigger # Field to store the trigger code fragment or artifact
)

def __lt__(self, other):
return isinstance(other, AnalysisResult) and (
Expand Down Expand Up @@ -296,11 +296,13 @@ def to_dict(self, verbosity: Severity = Severity.SUSPICIOUS):
analysis_message = self.to_string(verbosity)
severity_data = {
"severity": self.severity.name,
"analysis": analysis_message
if analysis_message.strip()
else "Warning: Fickling failed to detect any overtly unsafe code, but the pickle file"
"may still be unsafe."
"Do not unpickle this file if it is from an untrusted source!\n\n",
"analysis": (
analysis_message
if analysis_message.strip()
else "Warning: Fickling failed to detect any overtly unsafe code,"
"but the pickle file may still be unsafe."
"Do not unpickle this file if it is from an untrusted source!\n\n"
),
"detailed_results": self.detailed_results(),
}
return severity_data
Expand Down
2 changes: 1 addition & 1 deletion fickling/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main(argv: Optional[List[str]] = None) -> int:
"-i",
type=str,
default=None,
help="inject the specified Python code to be run at the end of depickling, "
help="inject the specified Python code to be run at the end of unpickling, "
"and output the resulting pickle data",
)
parser.add_argument(
Expand Down
8 changes: 3 additions & 5 deletions fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def append_python(
def insert_magic_int(self, magic: int, index: int = -1):
"""Insert and pop a specific integer value. This is used for persistent
injections to locate the injection payload in the pickled file. The value
is artificially added by using an dummy INT + POP combination that doesn't
is artificially added by using a dummy INT + POP combination that doesn't
affect the stack when executed
:param magic: magic integer value to add
Expand Down Expand Up @@ -760,13 +760,11 @@ def __init__(self, initial_value: Iterable[T] = ()):

@overload
@abstractmethod
def __getitem__(self, i: int) -> T:
...
def __getitem__(self, i: int) -> T: ...

@overload
@abstractmethod
def __getitem__(self, s: slice) -> GenericSequence:
...
def __getitem__(self, s: slice) -> GenericSequence: ...

def __getitem__(self, i: int) -> T:
return self._stack[i]
Expand Down
151 changes: 95 additions & 56 deletions fickling/polyglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
• TorchScript v1.3: ZIP file with data.pkl and constants.pkl (2 pickle files)
• TorchScript v1.4: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
• PyTorch v1.3: ZIP file containing data.pkl (1 pickle file)
• PyTorch model archive format: ZIP file that includes Python code files and pickle files
Officially, PyTorch v0.1.1 and TorchScript < v1.4 are deprecated.
However, they are still supported by some legacy parsers
• PyTorch model archive format[ZIP]: ZIP file that includes Python code files and pickle files
This description draws from this PyTorch GitHub issue: https://github.com/pytorch/pytorch/issues/31877.
If any inaccuracies in that description are found, that should be reflected in this code.
Expand Down Expand Up @@ -71,14 +68,14 @@ def check_pickle(file):


def find_file_properties(file_path, print_properties=False):
"""For a more granular analysis, we separate property discover and format identification"""
"""For a more granular analysis, we separate property discovery and format identification"""
properties = {}
with open(file_path, "rb") as file:
# PyTorch's torch.load() enforces a specific magic number at offset 0 for ZIP
is_torch_zip = _is_zipfile(file)
properties["is_torch_zip"] = is_torch_zip

# This tarfile check has many false positivies. It is not a determinant of PyTorch v0.1.1.
# This tarfile check has many false positives. It is not a determinant of PyTorch v0.1.1.
if sys.version_info >= (3, 9):
is_tar = tarfile.is_tarfile(file)
else:
Expand Down Expand Up @@ -175,10 +172,12 @@ def check_for_corruption(properties):
return corrupted, reason


def identify_pytorch_file_format(file, print_properties=False):
def identify_pytorch_file_format(file, print_properties=False, print_results=False):
"""
We are intentionally matching the semantics of the PyTorch reference parsers.
To be polyglot-aware, we show the file formats ranked by likelihood.
Our parsing depth is at the file structure level;
However, it can be at the full parsing level if necessary.
"""
properties = find_file_properties(file, print_properties)
formats = []
Expand Down Expand Up @@ -215,16 +214,21 @@ def identify_pytorch_file_format(file, print_properties=False):
print(reason)
if len(formats) != 0:
primary = formats[0]
print("Your file is most likely of this format: ", primary, "\n")
if print_results:
print("Your file is most likely of this format: ", primary, "\n")
secondary = formats[1:]
if len(secondary) != 0:
print("It is also possible that your file can be validly interpreted as: ", secondary)
if print_results:
print(
"It is also possible that your file can be validly interpreted as: ", secondary
)
else:
print(
"""Your file may not be a PyTorch file.
No valid file formats were detected.
If this is a mistake, raise an issue on our GitHub."""
)
if print_results:
print(
"""Your file may not be a PyTorch file.
No valid file formats were detected.
If this is a mistake, raise an issue on our GitHub."""
)
return formats


Expand All @@ -236,11 +240,67 @@ def append_file(source_filename, destination_filename):
return destination_filename


def make_zip_pickle_polyglot(zip_file, pickle_file, copy=False):
def create_zip_pickle_polyglot(zip_file, pickle_file):
append_file(zip_file, pickle_file)


def create_polyglot(first_file, second_file):
def create_mar_legacy_pickle_polyglot(
files, print_results=False, polyglot_file_name="polyglot.mar.pt"
):
files.sort(key=lambda x: x[1] != "PyTorch model archive format")
if print_results:
print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot")
polyglot_file = append_file(*[file[0] for file in files])
shutil.copy(polyglot_file, polyglot_file_name)
polyglot_found = True
return polyglot_found


def create_standard_torchscript_polyglot(
files, print_results=False, polyglot_file_name="polyglot.pt"
):
if print_results:
print("Making a PyTorch v1.3/TorchScript v1.4 polyglot")
print("Warning: For some parsers, this may generate polymocks instead of polyglots.")
standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0]
torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0]
if polyglot_file_name is None:
polyglot_file_name = "polyglot.pt"
shutil.copy(standard_pytorch_file, polyglot_file_name)

with zipfile.ZipFile(torchscript_file, "r") as zip_b:
constants_pkl_path = check_and_find_in_zip(
zip_b, "constants.pkl", check_extension=False, return_path=True
)
version_path = check_and_find_in_zip(zip_b, "version", return_path=True)
if constants_pkl_path and version_path:
zip_b.extract(constants_pkl_path, "temp")
zip_b.extract(version_path, "temp")

with zipfile.ZipFile(polyglot_file_name, "a") as zip_out:
zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl")
zip_out.write(f"temp/{version_path}", "version")

shutil.rmtree("temp")
polyglot_found = True
return polyglot_found


def create_mar_legacy_tar_polyglot(
files, print_results=False, polyglot_file_name="polyglot.mar.tar"
):
if print_results:
print("Making a PyTorch v0.1.1/PyTorch MAR polyglot")
mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0]
tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0]
polyglot_file = append_file(mar_file, tar_file)
shutil.copy(polyglot_file, polyglot_file_name)
polyglot_found = True
return polyglot_found


def create_polyglot(first_file, second_file, polyglot_file_name=None, print_results=True):
polyglot_found = False
temp_first_file = "temp_" + os.path.basename(first_file)
temp_second_file = "temp_" + os.path.basename(second_file)
shutil.copy(first_file, temp_first_file)
Expand All @@ -250,49 +310,28 @@ def create_polyglot(first_file, second_file):
(temp_second_file, identify_pytorch_file_format(temp_second_file)[0]),
]
formats = set(map(lambda x: x[1], files)) # noqa
polyglot_found = False
if {"PyTorch model archive format", "PyTorch v0.1.10"}.issubset(formats):
files.sort(key=lambda x: x[1] != "PyTorch model archive format")
print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot")
polyglot_found = True
polyglot_file = append_file(*[file[0] for file in files])
shutil.copy(polyglot_file, "polyglot.mar.pt")
print("The polyglot is contained in polyglot.mar.pt")
if polyglot_file_name is None:
polyglot_file_name = "polyglot.mar.pt"
polyglot_found = create_mar_legacy_pickle_polyglot(files, print_results, polyglot_file_name)
if {"PyTorch v1.3", "TorchScript v1.4"}.issubset(formats):
print("Making a PyTorch v1.3/TorchScript v1.4 polyglot")
print("Warning: For some parsers, this may generate polymocks instead of polyglots.")
polyglot_found = True
standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0]
torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0]
shutil.copy(standard_pytorch_file, "polyglot.pt")

with zipfile.ZipFile(torchscript_file, "r") as zip_b:
constants_pkl_path = check_and_find_in_zip(
zip_b, "constants.pkl", check_extension=False, return_path=True
)
version_path = check_and_find_in_zip(zip_b, "version", return_path=True)
if constants_pkl_path and version_path:
zip_b.extract(constants_pkl_path, "temp")
zip_b.extract(version_path, "temp")

with zipfile.ZipFile("polyglot.pt", "a") as zip_out:
zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl")
zip_out.write(f"temp/{version_path}", "version")

shutil.rmtree("temp")
if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats):
print("Making a PyTorch v0.1.1/PyTorch MAR polyglot")
polyglot_found = True
mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0]
tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0]
polyglot_file = append_file(mar_file, tar_file)
shutil.copy(polyglot_file, "polyglot.mar.tar")
print("The polyglot is contained in polyglot.mar.tar")
if polyglot_found is False:
print(
"""Fickling was not able to create any polglots.
If you think this is a mistake, raise an issue on our GitHub."""
if polyglot_file_name is None:
polyglot_file_name = "polyglot.pt"
polyglot_found = create_standard_torchscript_polyglot(
files, print_results, polyglot_file_name
)
if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats):
if polyglot_file_name is None:
polyglot_file_name = "polyglot.mar.tar"
polyglot_found = create_mar_legacy_tar_polyglot(files, print_results, polyglot_file_name)
if print_results:
if polyglot_found is False:
print(
"""Fickling was not able to create any polyglots.
If you think this is a mistake, raise an issue on our GitHub."""
)
else:
print(f"The polyglot is contained in {polyglot_file_name}")
os.remove(temp_first_file)
os.remove(temp_second_file)
return polyglot_found
Loading

0 comments on commit 03c3185

Please sign in to comment.