Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add depyf.optimization and tests #58

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test_pytorch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
echo "success"
- name: Test with pytest
run: |
coverage run --append tests/test_pytorch/test_wrapper.py
coverage run --append tests/test_pytorch/test_mp.py
coverage run --append tests/test_pytorch/test_no_graph.py
coverage run --append tests/test_pytorch/test_irregular.py
Expand Down
74 changes: 74 additions & 0 deletions depyf/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List

import torch


class TorchCompileWrapperWithCustomDispacther:
"""
A wrapper class for torch.compile, with a custom dispatch logic.
Subclasses should:
1. Implement the forward method
2. Implement the dispatch logic in the __call__ method
It can use `self.compiled_codes` to access the compiled bytecode,
and `with self.dispatch_to_code(index):` to dispatch to
the compiled code.
3. Implement the `__init__` method to determine how to call
`torch.compile` over the forward method.
"""

def __init__(self, compiled_callable: Callable, use_custom_dispatcher: bool = True):
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

self.use_custom_dispatcher: bool = use_custom_dispatcher

def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
method, for directly dispatching to the compiled code.
"""
return self.compiled_callable(*args, **kwargs)

Check warning on line 37 in depyf/optimization.py

View check run for this annotation

Codecov / codecov/patch

depyf/optimization.py#L37

Added line #L37 was not covered by tests

@abstractmethod
def forward(self, *args, **kwargs):
...

Check warning on line 41 in depyf/optimization.py

View check run for this annotation

Codecov / codecov/patch

depyf/optimization.py#L41

Added line #L41 was not covered by tests

def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
"""Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object:
return

Check warning on line 46 in depyf/optimization.py

View check run for this annotation

Codecov / codecov/patch

depyf/optimization.py#L46

Added line #L46 was not covered by tests
frame = sys._getframe()
while True:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
if code_name == "_compile" and file_name == "convert_frame.py":
break
frame = frame.f_locals["frame"]
assert frame.f_code == old_code

if frame.f_locals["self"] is not self:
return

self.compiled_codes.append(new_code)

@contextmanager
def dispatch_to_code(self, index: int):
"""Context manager to dispatch to the compiled code.
Why does this work? Because Dynamo guarantees that the compiled
bytecode has exactly the same arguments, cell variables, and free
variables as the original code. Therefore we can directly switch
the code object in the function and call it.

See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
""" # noqa
self.__class__.forward.__code__ = self.compiled_codes[index]
yield
self.__class__.forward.__code__ = self.original_code_object
58 changes: 58 additions & 0 deletions tests/test_pytorch/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Optional

import torch

from depyf.optimization import TorchCompileWrapperWithCustomDispacther


class MyMod(torch.nn.Module):

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
if cache is not None:
return x + cache
return x * 2


class MyWrapper(TorchCompileWrapperWithCustomDispacther):

def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)

def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
return self.model(x, cache)

def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# let torch.compile compile twice
if len(self.compiled_codes) == 2:
dispatch_id = 0 if cache is None else 1
with self.dispatch_to_code(dispatch_id):
return self.forward(x, cache)
else:
return self.compiled_callable(x, cache)


mod = MyMod()
wrappers = []
for i in range(3):
torch._dynamo.reset()
wrapper = MyWrapper(mod)
wrappers.append(wrapper)
x = torch.tensor([1])
wrapper(x, None) # profile run, compile
# create a cache tensor
cache = torch.tensor([2])
wrapper(x, cache) # warm up with cache, recompile

# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x,
None).item() == 6 # dispatch to the first compiled code
assert wrapper(
new_x, cache).item() == 5 # dispatch to the second compiled code

for wrapper in wrappers:
# make sure they have independent compiled codes
assert len(wrapper.compiled_codes) == 2
Loading