Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine PyTorch and Polyglot modules #89

Merged
merged 17 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ Module(

### PyTorch polyglots

We currently support inspecting, identifying, and creating file polyglots between the
PyTorch contains multiple file formats with which one can make polyglot files, which
are files that can be validly interpreted as more than one file format.
Fickling supports identifying, inspecting, and creating polyglots with the
following PyTorch file formats:

* **PyTorch v0.1.1**: Tar file with sys_info, pickle, storages, and tensors
Expand All @@ -172,7 +174,7 @@ following PyTorch file formats:
* **TorchScript v1.3**: ZIP file with data.pkl and constants.pkl (2 pickle files)
* **TorchScript v1.4**: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
* **PyTorch v1.3**: ZIP file containing data.pkl (1 pickle file)
* **PyTorch model archive format**: ZIP file that includes Python code files and pickle files
* **PyTorch model archive format[ZIP]**: ZIP file that includes Python code files and pickle files

```python
>> import torch
Expand Down
3 changes: 2 additions & 1 deletion example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* [context_manager.py](https://github.com/trailofbits/fickling/blob/master/example/context_manager.py): Halt the deserialization of a malicious pickle file with the fickling context manager
* [fault_injection.py](https://github.com/trailofbits/fickling/blob/master/example/fault_injection.py): Perform a fault injection on a PyTorch model and then analyze the result with `check_safety`
* [inject_mobilenet.py](https://github.com/trailofbits/fickling/blob/master/example/inject_mobilenet.py): Override the `eval` method of a ML model using fickling and apply `fickling.is_likely_safe` to the model file
* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using fickling’s PyTorch module
* [inject_pytorch.py](https://github.com/trailofbits/fickling/blob/master/example/inject_pytorch.py): Inject a model loaded from a PyTorch file with malicious code using the PyTorch module
* [numpy_poc.py](https://github.com/trailofbits/fickling/blob/master/example/numpy_poc.py): Analyze a malicious payload passed to `numpy.load()`
* [trace_binary.py](https://github.com/trailofbits/fickling/blob/master/example/trace_binary.py): Decompile a payload using the tracing module
* [identify_pytorch_file.py](https://github.com/trailofbits/fickling/blob/master/example/identify_pytorch_file.py): Identify 2 PyTorch files that are different formats
1 change: 1 addition & 0 deletions example/fault_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

Note you may need to run `pip install pytorchfi`
"""

import pickle

import torch
Expand Down
2 changes: 1 addition & 1 deletion example/hook_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# Set up global fickling hook
fickling.always_check_safety()
# Eauivalent to fickling.hook.run_hook()
# Equivalent to fickling.hook.run_hook()

# Fickling can check a pickle file for safety prior to running it
test_list = [1, 2, 3]
Expand Down
16 changes: 16 additions & 0 deletions example/identify_pytorch_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
import torchvision.models as models

import fickling.polyglot as polyglot

model = models.mobilenet_v2()
torch.save(model, "mobilenet.pth")
torch.save(model, "legacy_mobilenet.pth", _use_new_zipfile_serialization=False)

print("Identifying PyTorch v1.3 file:")
potential_formats = polyglot.identify_pytorch_file_format("mobilenet.pth", print_results=True)

print("Identifying PyTorch v0.1.10 file:")
potential_formats_legacy = polyglot.identify_pytorch_file_format(
"legacy_mobilenet.pth", print_results=True
)
2 changes: 1 addition & 1 deletion fickling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

# The above lines enables `fickling.load()` and `with fickling.check_safety()`
# The comments are necessary to comply with linters
__version__ = "0.0.8"
__version__ = "0.1.0"
18 changes: 10 additions & 8 deletions fickling/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def __init__(
self.severity: Severity = severity
self.message: Optional[str] = message
self.analysis_name: str = analysis_name
self.trigger: Optional[
str
] = trigger # Field to store the trigger code fragment or artifact
self.trigger: Optional[str] = (
trigger # Field to store the trigger code fragment or artifact
)

def __lt__(self, other):
return isinstance(other, AnalysisResult) and (
Expand Down Expand Up @@ -296,11 +296,13 @@ def to_dict(self, verbosity: Severity = Severity.SUSPICIOUS):
analysis_message = self.to_string(verbosity)
severity_data = {
"severity": self.severity.name,
"analysis": analysis_message
if analysis_message.strip()
else "Warning: Fickling failed to detect any overtly unsafe code, but the pickle file"
"may still be unsafe."
"Do not unpickle this file if it is from an untrusted source!\n\n",
"analysis": (
analysis_message
if analysis_message.strip()
else "Warning: Fickling failed to detect any overtly unsafe code,"
"but the pickle file may still be unsafe."
"Do not unpickle this file if it is from an untrusted source!\n\n"
),
"detailed_results": self.detailed_results(),
}
return severity_data
Expand Down
2 changes: 1 addition & 1 deletion fickling/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main(argv: Optional[List[str]] = None) -> int:
"-i",
type=str,
default=None,
help="inject the specified Python code to be run at the end of depickling, "
help="inject the specified Python code to be run at the end of unpickling, "
"and output the resulting pickle data",
)
parser.add_argument(
Expand Down
8 changes: 3 additions & 5 deletions fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def append_python(
def insert_magic_int(self, magic: int, index: int = -1):
"""Insert and pop a specific integer value. This is used for persistent
injections to locate the injection payload in the pickled file. The value
is artificially added by using an dummy INT + POP combination that doesn't
is artificially added by using a dummy INT + POP combination that doesn't
affect the stack when executed

:param magic: magic integer value to add
Expand Down Expand Up @@ -760,13 +760,11 @@ def __init__(self, initial_value: Iterable[T] = ()):

@overload
@abstractmethod
def __getitem__(self, i: int) -> T:
...
def __getitem__(self, i: int) -> T: ...

@overload
@abstractmethod
def __getitem__(self, s: slice) -> GenericSequence:
...
def __getitem__(self, s: slice) -> GenericSequence: ...

def __getitem__(self, i: int) -> T:
return self._stack[i]
Expand Down
151 changes: 95 additions & 56 deletions fickling/polyglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
• TorchScript v1.3: ZIP file with data.pkl and constants.pkl (2 pickle files)
• TorchScript v1.4: ZIP file with data.pkl, constants.pkl, and version (2 pickle files and a folder)
• PyTorch v1.3: ZIP file containing data.pkl (1 pickle file)
• PyTorch model archive format: ZIP file that includes Python code files and pickle files

Officially, PyTorch v0.1.1 and TorchScript < v1.4 are deprecated.
However, they are still supported by some legacy parsers
• PyTorch model archive format[ZIP]: ZIP file that includes Python code files and pickle files

This description draws from this PyTorch GitHub issue: https://github.com/pytorch/pytorch/issues/31877.
If any inaccuracies in that description are found, that should be reflected in this code.
Expand Down Expand Up @@ -71,14 +68,14 @@ def check_pickle(file):


def find_file_properties(file_path, print_properties=False):
"""For a more granular analysis, we separate property discover and format identification"""
"""For a more granular analysis, we separate property discovery and format identification"""
properties = {}
with open(file_path, "rb") as file:
# PyTorch's torch.load() enforces a specific magic number at offset 0 for ZIP
is_torch_zip = _is_zipfile(file)
properties["is_torch_zip"] = is_torch_zip

# This tarfile check has many false positivies. It is not a determinant of PyTorch v0.1.1.
# This tarfile check has many false positives. It is not a determinant of PyTorch v0.1.1.
if sys.version_info >= (3, 9):
is_tar = tarfile.is_tarfile(file)
else:
Expand Down Expand Up @@ -175,10 +172,12 @@ def check_for_corruption(properties):
return corrupted, reason


def identify_pytorch_file_format(file, print_properties=False):
def identify_pytorch_file_format(file, print_properties=False, print_results=False):
"""
We are intentionally matching the semantics of the PyTorch reference parsers.
To be polyglot-aware, we show the file formats ranked by likelihood.
Our parsing depth is at the file structure level;
However, it can be at the full parsing level if necessary.
"""
properties = find_file_properties(file, print_properties)
formats = []
Expand Down Expand Up @@ -215,16 +214,21 @@ def identify_pytorch_file_format(file, print_properties=False):
print(reason)
if len(formats) != 0:
primary = formats[0]
print("Your file is most likely of this format: ", primary, "\n")
if print_results:
print("Your file is most likely of this format: ", primary, "\n")
secondary = formats[1:]
if len(secondary) != 0:
print("It is also possible that your file can be validly interpreted as: ", secondary)
if print_results:
print(
"It is also possible that your file can be validly interpreted as: ", secondary
)
else:
print(
"""Your file may not be a PyTorch file.
No valid file formats were detected.
If this is a mistake, raise an issue on our GitHub."""
)
if print_results:
print(
"""Your file may not be a PyTorch file.
No valid file formats were detected.
If this is a mistake, raise an issue on our GitHub."""
)
return formats


Expand All @@ -236,11 +240,67 @@ def append_file(source_filename, destination_filename):
return destination_filename


def make_zip_pickle_polyglot(zip_file, pickle_file, copy=False):
def create_zip_pickle_polyglot(zip_file, pickle_file):
append_file(zip_file, pickle_file)


def create_polyglot(first_file, second_file):
def create_mar_legacy_pickle_polyglot(
files, print_results=False, polyglot_file_name="polyglot.mar.pt"
):
files.sort(key=lambda x: x[1] != "PyTorch model archive format")
if print_results:
print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot")
polyglot_file = append_file(*[file[0] for file in files])
shutil.copy(polyglot_file, polyglot_file_name)
polyglot_found = True
return polyglot_found


def create_standard_torchscript_polyglot(
files, print_results=False, polyglot_file_name="polyglot.pt"
):
if print_results:
print("Making a PyTorch v1.3/TorchScript v1.4 polyglot")
print("Warning: For some parsers, this may generate polymocks instead of polyglots.")
standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0]
torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0]
if polyglot_file_name is None:
polyglot_file_name = "polyglot.pt"
shutil.copy(standard_pytorch_file, polyglot_file_name)

with zipfile.ZipFile(torchscript_file, "r") as zip_b:
constants_pkl_path = check_and_find_in_zip(
zip_b, "constants.pkl", check_extension=False, return_path=True
)
version_path = check_and_find_in_zip(zip_b, "version", return_path=True)
if constants_pkl_path and version_path:
zip_b.extract(constants_pkl_path, "temp")
zip_b.extract(version_path, "temp")

with zipfile.ZipFile(polyglot_file_name, "a") as zip_out:
zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl")
zip_out.write(f"temp/{version_path}", "version")

shutil.rmtree("temp")
polyglot_found = True
return polyglot_found


def create_mar_legacy_tar_polyglot(
files, print_results=False, polyglot_file_name="polyglot.mar.tar"
):
if print_results:
print("Making a PyTorch v0.1.1/PyTorch MAR polyglot")
mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0]
tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0]
polyglot_file = append_file(mar_file, tar_file)
shutil.copy(polyglot_file, polyglot_file_name)
polyglot_found = True
return polyglot_found


def create_polyglot(first_file, second_file, polyglot_file_name=None, print_results=True):
polyglot_found = False
temp_first_file = "temp_" + os.path.basename(first_file)
temp_second_file = "temp_" + os.path.basename(second_file)
shutil.copy(first_file, temp_first_file)
Expand All @@ -250,49 +310,28 @@ def create_polyglot(first_file, second_file):
(temp_second_file, identify_pytorch_file_format(temp_second_file)[0]),
]
formats = set(map(lambda x: x[1], files)) # noqa
polyglot_found = False
if {"PyTorch model archive format", "PyTorch v0.1.10"}.issubset(formats):
files.sort(key=lambda x: x[1] != "PyTorch model archive format")
print("Making a PyTorch MAR/PyTorch v0.1.10 polyglot")
polyglot_found = True
polyglot_file = append_file(*[file[0] for file in files])
shutil.copy(polyglot_file, "polyglot.mar.pt")
print("The polyglot is contained in polyglot.mar.pt")
if polyglot_file_name is None:
polyglot_file_name = "polyglot.mar.pt"
polyglot_found = create_mar_legacy_pickle_polyglot(files, print_results, polyglot_file_name)
if {"PyTorch v1.3", "TorchScript v1.4"}.issubset(formats):
print("Making a PyTorch v1.3/TorchScript v1.4 polyglot")
print("Warning: For some parsers, this may generate polymocks instead of polyglots.")
polyglot_found = True
standard_pytorch_file = [file[0] for file in files if file[1] == "PyTorch v1.3"][0]
torchscript_file = [file[0] for file in files if file[1] == "TorchScript v1.4"][0]
shutil.copy(standard_pytorch_file, "polyglot.pt")

with zipfile.ZipFile(torchscript_file, "r") as zip_b:
constants_pkl_path = check_and_find_in_zip(
zip_b, "constants.pkl", check_extension=False, return_path=True
)
version_path = check_and_find_in_zip(zip_b, "version", return_path=True)
if constants_pkl_path and version_path:
zip_b.extract(constants_pkl_path, "temp")
zip_b.extract(version_path, "temp")

with zipfile.ZipFile("polyglot.pt", "a") as zip_out:
zip_out.write(f"temp/{constants_pkl_path}", "constants.pkl")
zip_out.write(f"temp/{version_path}", "version")

shutil.rmtree("temp")
if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats):
print("Making a PyTorch v0.1.1/PyTorch MAR polyglot")
polyglot_found = True
mar_file = [file[0] for file in files if file[1] == "PyTorch model archive format"][0]
tar_file = [file[0] for file in files if file[1] == "PyTorch v0.1.1"][0]
polyglot_file = append_file(mar_file, tar_file)
shutil.copy(polyglot_file, "polyglot.mar.tar")
print("The polyglot is contained in polyglot.mar.tar")
if polyglot_found is False:
print(
"""Fickling was not able to create any polglots.
If you think this is a mistake, raise an issue on our GitHub."""
if polyglot_file_name is None:
polyglot_file_name = "polyglot.pt"
polyglot_found = create_standard_torchscript_polyglot(
files, print_results, polyglot_file_name
)
if {"PyTorch model archive format", "PyTorch v0.1.1"}.issubset(formats):
if polyglot_file_name is None:
polyglot_file_name = "polyglot.mar.tar"
polyglot_found = create_mar_legacy_tar_polyglot(files, print_results, polyglot_file_name)
if print_results:
if polyglot_found is False:
print(
"""Fickling was not able to create any polyglots.
If you think this is a mistake, raise an issue on our GitHub."""
)
else:
print(f"The polyglot is contained in {polyglot_file_name}")
os.remove(temp_first_file)
os.remove(temp_second_file)
return polyglot_found
Loading
Loading