Merge pull request #30 from chengzeyi/dev
Dev
chengzeyi authored Nov 13, 2023
2 parents a4ad186 + 3e31fa4 commit 2ddea8f
Showing 5 changed files with 370 additions and 171 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -13,6 +13,7 @@ __NOTE__: `stable-fast` is currently only in beta stage and is prone to be buggy
- [Install From Source](#install-from-source)
- [Usage](#usage)
- [Optimize StableDiffusionPipeline](#optimize-stablediffusionpipeline)
- [Dynamically Switch LoRA](#dynamically-switch-lora)
- [Some Common Methods To Speed Up PyTorch](#some-common-methods-to-speed-up-pytorch)
- [Trouble Shooting](#trouble-shooting)

@@ -249,6 +250,40 @@ output_image = compiled_model(**kwarg_inputs).images[0]
output_image = compiled_model(**kwarg_inputs).images[0]
```

### Dynamically Switch LoRA

Switching LoRA dynamically is supported, but it requires some extra work.
It is possible because the compiled graph and the `CUDA Graph` share the same
underlying data (pointers) with the original UNet model, so all you need to do
is update the original UNet model's parameters in place.

The following code assumes you have already loaded a LoRA and compiled the model,
and now want to switch to another LoRA.

```python
def update_state_dict(dst, src):
for key, value in src.items():
# Do inplace copy.
        # As the traced forward function holds references to the same tensors,
        # this modification will be reflected in the traced forward function.
dst[key].copy_(value)

# Switch "another" LoRA into UNet
def switch_lora(unet, lora):
# Store the original UNet parameters
state_dict = unet.state_dict()
# Load another LoRA into unet
unet.load_attn_procs(lora)
    # Copy the current UNet parameters into the original parameter tensors in place
update_state_dict(state_dict, unet.state_dict())
# Load the original UNet parameters back.
    # We use assign=True because we still want to keep the references
    # to the original UNet parameters.
unet.load_state_dict(state_dict, assign=True)

switch_lora(compiled_model.unet, lora_b_path)
```
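
If you want to check that the switch really kept the original parameter storage
(so the compiled graph and `CUDA Graph` still see the updated weights), a minimal
sanity check could look like the sketch below. It is only an illustration built on
the `switch_lora` helper above, not part of `stable-fast`:

```python
# Illustrative sanity check (not part of stable-fast): the data pointers of the
# UNet parameters should be unchanged after switching LoRA; otherwise the
# captured CUDA Graph would still read the old weights.
before = {k: v.data_ptr() for k, v in compiled_model.unet.state_dict().items()}
switch_lora(compiled_model.unet, lora_b_path)
after = {k: v.data_ptr() for k, v in compiled_model.unet.state_dict().items()}
assert all(before[k] == after[k] for k in before), "parameter storage was reallocated"
```

Note that the `assign` keyword of `load_state_dict()` used in `switch_lora` is only
available in newer PyTorch releases.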

### Some Common Methods To Speed Up PyTorch

```bash
122 changes: 73 additions & 49 deletions sfast/compilers/stable_diffusion_pipeline_compiler.py
@@ -6,7 +6,7 @@
import torch
import sfast
from sfast.jit import passes
from sfast.jit.trace_helper import (lazy_trace, to_module)
from sfast.jit.trace_helper import lazy_trace, to_module
from sfast.jit import utils as jit_utils
from sfast.cuda.graphs import make_dynamic_graphed_callable
from sfast.utils import gpu_device
@@ -15,38 +15,41 @@


class CompilationConfig:

@dataclass
class Default:
memory_format: torch.memory_format = torch.channels_last if gpu_device.device_has_tensor_core(
) else torch.contiguous_format
memory_format: torch.memory_format = (
torch.channels_last
if gpu_device.device_has_tensor_core()
else torch.contiguous_format
)
enable_jit: bool = True
enable_jit_freeze: bool = True
enable_cnn_optimization: bool = True
prefer_lowp_gemm: bool = True
enable_xformers: bool = False
enable_cuda_graph: bool = False
enable_triton: bool = False
enable_quantization: bool = False
trace_scheduler: bool = False


def compile(m, config):
with torch.no_grad():
enable_cuda_graph = config.enable_cuda_graph and m.device.type == 'cuda'
enable_cuda_graph = config.enable_cuda_graph and m.device.type == "cuda"

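        # Patch set_timesteps so the timestep schedule is always computed on the CPU,
        # regardless of the device argument passed by the caller.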
scheduler = m.scheduler
scheduler._set_timesteps = scheduler.set_timesteps

def set_timesteps(self, num_timesteps: int,
device: Union[str, torch.device]):
return self._set_timesteps(num_timesteps,
device=torch.device('cpu'))
def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]):
return self._set_timesteps(num_timesteps, device=torch.device("cpu"))

scheduler.set_timesteps = set_timesteps.__get__(scheduler)

if config.enable_xformers:
if config.enable_jit:
from sfast.utils.xformers_attention import xformers_memory_efficient_attention
from sfast.utils.xformers_attention import (
xformers_memory_efficient_attention,
)
from xformers import ops

ops.memory_efficient_attention = xformers_memory_efficient_attention
@@ -56,9 +59,14 @@ def set_timesteps(self, num_timesteps: int,
if config.memory_format == torch.channels_last:
m.unet.to(memory_format=torch.channels_last)
m.vae.to(memory_format=torch.channels_last)
if hasattr(m, 'controlnet'):
if hasattr(m, "controlnet"):
m.controlnet.to(memory_format=torch.channels_last)

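        # Optionally replace the UNet's Linear layers with dynamically quantized
        # int8 versions (torch.quantization.quantize_dynamic modifies the model in place).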
if config.enable_quantization:
m.unet = torch.quantization.quantize_dynamic(
m.unet, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)

if config.enable_jit:
modify_model = functools.partial(
_modify_model,
@@ -68,18 +76,27 @@ def set_timesteps(self, num_timesteps: int,
memory_format=config.memory_format,
)

def ts_compiler(m,
call_helper,
inputs,
kwarg_inputs,
freeze=False,
enable_cuda_graph=False):
def ts_compiler(
m,
call_helper,
inputs,
kwarg_inputs,
freeze=False,
enable_cuda_graph=False,
):
with torch.jit.optimized_execution(True):
if freeze:
# raw freeze causes Tensor reference leak
# because the constant Tensors in the GraphFunction of
# the compilation unit are never freed.
m = jit_utils.better_freeze(m)
m = jit_utils.better_freeze(
m,
# preserve_parameters=True is probably not needed
# for load_state_dict() to work, because the
# traced graph and CUDA graph shares the same underlying
# data (pointer) for the parameters.
# preserve_parameters=False,
)
modify_model(m)

if enable_cuda_graph:
@@ -93,18 +110,20 @@ def ts_compiler(m,
freeze=config.enable_jit_freeze,
),
check_trace=False,
strict=False)

m.text_encoder.forward = lazy_trace_(
to_module(m.text_encoder.forward))
unet_forward = lazy_trace(to_module(m.unet.forward),
ts_compiler=functools.partial(
ts_compiler,
freeze=config.enable_jit_freeze,
enable_cuda_graph=enable_cuda_graph,
),
check_trace=False,
strict=False)
strict=False,
)

m.text_encoder.forward = lazy_trace_(to_module(m.text_encoder.forward))
unet_forward = lazy_trace(
to_module(m.unet.forward),
ts_compiler=functools.partial(
ts_compiler,
freeze=config.enable_jit_freeze,
enable_cuda_graph=enable_cuda_graph,
),
check_trace=False,
strict=False,
)

@functools.wraps(m.unet.forward)
def unet_forward_wrapper(sample, t, *args, **kwargs):
@@ -113,28 +132,32 @@ def unet_forward_wrapper(sample, t, *args, **kwargs):

m.unet.forward = unet_forward_wrapper

if not packaging.version.parse('2.0.0') <= packaging.version.parse(
torch.__version__) < packaging.version.parse('2.1.0'):
'''
if (
not packaging.version.parse("2.0.0")
<= packaging.version.parse(torch.__version__)
< packaging.version.parse("2.1.0")
):
"""
Weird bug in PyTorch 2.0.x
RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152
When executing AttnProcessor in TorchScript
'''
"""
m.vae.decode = lazy_trace_(to_module(m.vae.decode))
# For img2img
m.vae.encoder.forward = lazy_trace_(
to_module(m.vae.encoder.forward))
m.vae.encoder.forward = lazy_trace_(to_module(m.vae.encoder.forward))
m.vae.quant_conv.forward = lazy_trace_(
to_module(m.vae.quant_conv.forward))
to_module(m.vae.quant_conv.forward)
)

if config.trace_scheduler:
m.scheduler.scale_model_input = lazy_trace_(
to_module(m.scheduler.scale_model_input))
to_module(m.scheduler.scale_model_input)
)
m.scheduler.step = lazy_trace_(to_module(m.scheduler.step))

if hasattr(m, 'controlnet'):
if hasattr(m, "controlnet"):
controlnet_forward = lazy_trace(
to_module(m.controlnet.forward),
ts_compiler=functools.partial(
@@ -143,7 +166,8 @@ def unet_forward_wrapper(sample, t, *args, **kwargs):
enable_cuda_graph=enable_cuda_graph,
),
check_trace=False,
strict=False)
strict=False,
)

@functools.wraps(m.controlnet.forward)
def controlnet_forward_wrapper(sample, t, *args, **kwargs):
@@ -155,11 +179,13 @@ def controlnet_forward_wrapper(sample, t, *args, **kwargs):
return m


def _modify_model(m,
enable_cnn_optimization=True,
prefer_lowp_gemm=True,
enable_triton=False,
memory_format=None):
def _modify_model(
m,
enable_cnn_optimization=True,
prefer_lowp_gemm=True,
enable_triton=False,
memory_format=None,
):
if enable_triton:
from sfast.jit.passes import triton_passes

@@ -181,10 +207,8 @@ def _modify_model(m,

if memory_format is not None:
sfast._C._jit_pass_convert_op_input_tensors(
m.graph,
'aten::_convolution',
indices=[0],
memory_format=memory_format)
m.graph, "aten::_convolution", indices=[0], memory_format=memory_format
)

if enable_cnn_optimization:
passes.jit_pass_optimize_cnn(m.graph)
64 changes: 54 additions & 10 deletions sfast/jit/utils.py
@@ -1,25 +1,26 @@
import inspect
import functools
import torch
import sfast
import functools


class ScriptModuleClearHook:

def __init__(self, script_module_c):
self.class_type = sfast._C._jit_get_module_type(script_module_c)

def __del__(self):
sfast._C._jit_clear_class_type_registration(self.class_type)


def attach_script_module_clear_hook(script_module,
attr_name='_module_registration_clear_hook'
):
def attach_script_module_clear_hook(
script_module, attr_name="_module_registration_clear_hook"
):
script_module._register_attribute(
attr_name, torch._C.PyObjectType.get(),
ScriptModuleClearHook(script_module))
for child_name, child_module in torch._C._jit_debug_module_iterators(
script_module)['named_children']:
attr_name, torch._C.PyObjectType.get(), ScriptModuleClearHook(script_module)
)
for child_name, child_module in torch._C._jit_debug_module_iterators(script_module)[
"named_children"
]:
attach_script_module_clear_hook(child_module, attr_name)


@@ -32,6 +33,49 @@ def better_trace(func, *args, **kwargs):

@functools.wraps(torch.jit.freeze)
def better_freeze(script_module, *args, **kwargs):
freezed_module = torch.jit.freeze(script_module, *args, **kwargs)
freeze = torch.jit.freeze
if (
"preserve_parameters" in kwargs
and "preserve_parameters" not in inspect.signature(freeze).parameters
):
from typing import List, Optional
from torch.jit._script import RecursiveScriptModule, ScriptModule

# Based on https://github.com/pytorch/pytorch/blob/7bcf7da3a268b435777fe87c7794c382f444e86d/torch/jit/_freeze.py#L13C1-L13C1
def freeze(
mod,
preserved_attrs: Optional[List[str]] = None,
optimize_numerics: bool = True,
preserve_parameters: bool = False,
):
if not isinstance(mod, ScriptModule):
raise RuntimeError(
"Freezing expects a ScriptModule as input. "
"Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
)

if mod.training:
raise RuntimeError(
"Freezing is currently only implemented for modules in eval mode. "
"Please call .eval() on your module before freezing."
)

preserved_attrs = preserved_attrs if preserved_attrs is not None else []

out = RecursiveScriptModule(
torch._C._freeze_module(
mod._c, preserved_attrs, preserveParameters=preserve_parameters
)
)
RecursiveScriptModule._finalize_scriptmodule(out)

preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
torch.jit.run_frozen_optimizations(
out, optimize_numerics, preserved_methods
)

return out

freezed_module = freeze(script_module, *args, **kwargs)
attach_script_module_clear_hook(freezed_module._c)
return freezed_module