Dev #33

Merged · 5 commits · Nov 14, 2023
7 changes: 5 additions & 2 deletions README.md
@@ -39,6 +39,7 @@ __NOTE__: `stable-fast` is currently in beta and may be buggy.
Performance varies greatly across different hardware/software/platform/driver configurations.
It is very hard to benchmark accurately, and preparing an environment for benchmarking is also hard work.
I have tested on some platforms before, but the results may still be inaccurate.
Note that, when benchmarking, the progress bar shown by `tqdm` may be inaccurate because of the asynchronous nature of CUDA.
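For example, a minimal timing sketch that accounts for CUDA's asynchronous execution (`pipe` stands for any already-constructed diffusers-style pipeline and is only illustrative):

```python
import time

import torch

# CUDA kernels are launched asynchronously, so synchronize before reading
# the clock; otherwise the measured time only covers kernel *launches*.
torch.cuda.synchronize()
begin = time.time()
output_image = pipe(prompt='a photo of a cat').images[0]  # `pipe` is assumed
torch.cuda.synchronize()
print(f'Inference time: {time.time() - begin:.3f}s')
```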

`stable-fast` is expected to work better on newer GPUs and newer CUDA versions.
__On older GPUs, the performance increase might be limited.__
@@ -253,18 +254,20 @@ output_image = compiled_model(**kwarg_inputs).images[0]
### Dynamically Switch LoRA

Switching LoRA dynamically is supported but you need to do some extra work.
It is possible because the compiled graph and `CUDA Graph` share the same
underlying data (pointers) with the original UNet model. So all you need to do
is update the original UNet model's parameters in place.

The following code assumes you have already loaded a LoRA and compiled the model,
and you want to switch to another LoRA.

```python
# load_state_dict with assign=True requires torch >= 2.1.0

def update_state_dict(dst, src):
    for key, value in src.items():
        # Do an in-place copy.
        # As the traced forward function shares the same underlying
        # data (pointers), this modification will be reflected in the
        # traced forward function.
        dst[key].copy_(value)

```
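A hedged usage sketch of the helper above (`get_lora_patched_state_dict` is a hypothetical stand-in for however you produce the new weights, e.g. by loading the LoRA into a separate, uncompiled copy of the UNet):

```python
import torch

# Hypothetical: a state dict containing the LoRA-patched UNet weights.
new_unet_state_dict = get_lora_patched_state_dict('other_lora.safetensors')

# Copy the new weights into the compiled UNet's parameters in place.
# no_grad() is needed because copy_() is an in-place op on leaf tensors.
with torch.no_grad():
    update_state_dict(compiled_model.unet.state_dict(), new_unet_state_dict)
```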
295 changes: 156 additions & 139 deletions sfast/compilers/stable_diffusion_pipeline_compiler.py
@@ -6,7 +6,7 @@
import torch
import sfast
from sfast.jit import passes
from sfast.jit.trace_helper import lazy_trace, to_module
from sfast.jit.trace_helper import lazy_trace
from sfast.jit import utils as jit_utils
from sfast.cuda.graphs import make_dynamic_graphed_callable
from sfast.utils import gpu_device
@@ -15,168 +15,117 @@


class CompilationConfig:

@dataclass
class Default:
'''
Default compilation config

memory_format:
channels_last if tensor core is available, otherwise contiguous_format.
On GPUs with tensor cores, channels_last is faster
enable_jit:
Whether to enable JIT, most optimizations are done with JIT
enable_jit_freeze:
Whether to freeze the model after JIT tracing.
Freezing the model will enable us to optimize the model further.
preserve_parameters:
Whether to preserve parameters when freezing the model.
If True, parameters will be preserved, but the model will be a bit slower.
If False, parameters will be marked as constants, and the model will be faster.
However, if parameters are not preserved, LoRA cannot be switched dynamically.
enable_cnn_optimization:
Whether to enable CNN optimization by fusion.
prefer_lowp_gemm:
Whether to prefer low-precision GEMM and a series of fusion optimizations.
This will make the model faster, but may cause numerical issues.
enable_xformers:
Whether to enable xformers and hijack it to make it compatible with JIT tracing.
enable_cuda_graph:
Whether to enable CUDA graph. CUDA Graph will significantly speed up the model,
by reducing the overhead of CUDA kernel launch, memory allocation, etc.
However, it will also increase the memory usage.
Our implementation of CUDA graph supports dynamic shape by caching graphs of
different shapes.
enable_triton:
Whether to enable Triton-generated CUDA kernels.
Triton-generated CUDA kernels are often faster than PyTorch's native kernels.
However, Triton has a lot of bugs and can increase CPU overhead,
though the overhead can be reduced by enabling CUDA graph.
trace_scheduler:
Whether to trace the scheduler.
'''
memory_format: torch.memory_format = (
torch.channels_last
if gpu_device.device_has_tensor_core()
else torch.contiguous_format
)
torch.channels_last if gpu_device.device_has_tensor_core() else
torch.contiguous_format)
enable_jit: bool = True
enable_jit_freeze: bool = True
preserve_parameters: bool = True
enable_cnn_optimization: bool = True
prefer_lowp_gemm: bool = True
enable_xformers: bool = False
enable_cuda_graph: bool = False
enable_triton: bool = False
enable_quantization: bool = False
trace_scheduler: bool = False
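
# A hedged usage sketch: constructing and tweaking the default config for
# the `compile` entry point defined below (`m` is a hypothetical diffusers
# pipeline already moved to CUDA):
#
#     config = CompilationConfig.Default()
#     config.enable_xformers = True    # requires xformers to be installed
#     config.enable_triton = True      # requires triton to be installed
#     config.enable_cuda_graph = True  # faster, at the cost of extra memory
#     m = compile(m, config)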


def compile(m, config):
with torch.no_grad():
enable_cuda_graph = config.enable_cuda_graph and m.device.type == "cuda"
m.unet = compile_unet(m.unet, config)
if hasattr(m, 'controlnet'):
m.controlnet = compile_unet(m.controlnet, config)

scheduler = m.scheduler
scheduler._set_timesteps = scheduler.set_timesteps
if config.enable_xformers:
_enable_xformers(m)

def set_timesteps(self, num_timesteps: int, device: Union[str, torch.device]):
return self._set_timesteps(num_timesteps, device=torch.device("cpu"))
if config.memory_format is not None:
m.vae.to(memory_format=config.memory_format)

scheduler.set_timesteps = set_timesteps.__get__(scheduler)
if config.enable_jit:
lazy_trace_ = _build_lazy_trace(config)

if config.enable_xformers:
if config.enable_jit:
from sfast.utils.xformers_attention import (
xformers_memory_efficient_attention,
)
from xformers import ops
m.text_encoder.forward = lazy_trace_(m.text_encoder.forward)
if (not packaging.version.parse('2.0.0') <= packaging.version.parse(
torch.__version__) < packaging.version.parse('2.1.0')):
"""
Weird bug in PyTorch 2.0.x

ops.memory_efficient_attention = xformers_memory_efficient_attention
RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152

m.enable_xformers_memory_efficient_attention()
When executing AttnProcessor in TorchScript
"""
m.vae.decode = lazy_trace_(m.vae.decode)
# For img2img
m.vae.encoder.forward = lazy_trace_(m.vae.encoder.forward)
m.vae.quant_conv.forward = lazy_trace_(m.vae.quant_conv.forward)
if config.trace_scheduler:
m.scheduler.scale_model_input = lazy_trace_(
m.scheduler.scale_model_input)
m.scheduler.step = lazy_trace_(m.scheduler.step)

if config.memory_format == torch.channels_last:
m.unet.to(memory_format=torch.channels_last)
m.vae.to(memory_format=torch.channels_last)
if hasattr(m, "controlnet"):
m.controlnet.to(memory_format=torch.channels_last)
return m

if config.enable_quantization:
m.unet = torch.quantization.quantize_dynamic(
m.unet, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
)

if config.enable_jit:
modify_model = functools.partial(
_modify_model,
enable_cnn_optimization=config.enable_cnn_optimization,
prefer_lowp_gemm=config.prefer_lowp_gemm,
enable_triton=config.enable_triton,
memory_format=config.memory_format,
)
def compile_unet(m, config):
# attribute `device` is not generally available
device = m.device if hasattr(m, 'device') else torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')

def ts_compiler(
m,
call_helper,
inputs,
kwarg_inputs,
freeze=False,
enable_cuda_graph=False,
):
with torch.jit.optimized_execution(True):
if freeze:
# raw freeze causes Tensor reference leak
# because the constant Tensors in the GraphFunction of
# the compilation unit are never freed.
m = jit_utils.better_freeze(
m,
# preserve_parameters=True is probably not needed
# for load_state_dict() to work, because the
# traced graph and CUDA graph share the same underlying
# data (pointer) for the parameters.
# preserve_parameters=False,
)
modify_model(m)

if enable_cuda_graph:
m = make_dynamic_graphed_callable(m)
return m

lazy_trace_ = functools.partial(
lazy_trace,
ts_compiler=functools.partial(
ts_compiler,
freeze=config.enable_jit_freeze,
),
check_trace=False,
strict=False,
)
enable_cuda_graph = config.enable_cuda_graph and device.type == 'cuda'

m.text_encoder.forward = lazy_trace_(to_module(m.text_encoder.forward))
unet_forward = lazy_trace(
to_module(m.unet.forward),
ts_compiler=functools.partial(
ts_compiler,
freeze=config.enable_jit_freeze,
enable_cuda_graph=enable_cuda_graph,
),
check_trace=False,
strict=False,
)
if config.enable_xformers:
_enable_xformers(m)

if config.memory_format is not None:
m.to(memory_format=config.memory_format)

if config.enable_jit:
lazy_trace_ = _build_lazy_trace(config)
m.forward = lazy_trace_(m.forward)

@functools.wraps(m.unet.forward)
def unet_forward_wrapper(sample, t, *args, **kwargs):
t = t.to(device=sample.device)
return unet_forward(sample, t, *args, **kwargs)

m.unet.forward = unet_forward_wrapper

if (
not packaging.version.parse("2.0.0")
<= packaging.version.parse(torch.__version__)
< packaging.version.parse("2.1.0")
):
"""
Weird bug in PyTorch 2.0.x

RuntimeError: shape '[512, 512, 64, 64]' is invalid for input of size 2097152

When executing AttnProcessor in TorchScript
"""
m.vae.decode = lazy_trace_(to_module(m.vae.decode))
# For img2img
m.vae.encoder.forward = lazy_trace_(to_module(m.vae.encoder.forward))
m.vae.quant_conv.forward = lazy_trace_(
to_module(m.vae.quant_conv.forward)
)

if config.trace_scheduler:
m.scheduler.scale_model_input = lazy_trace_(
to_module(m.scheduler.scale_model_input)
)
m.scheduler.step = lazy_trace_(to_module(m.scheduler.step))

if hasattr(m, "controlnet"):
controlnet_forward = lazy_trace(
to_module(m.controlnet.forward),
ts_compiler=functools.partial(
ts_compiler,
freeze=False,
enable_cuda_graph=enable_cuda_graph,
),
check_trace=False,
strict=False,
)

@functools.wraps(m.controlnet.forward)
def controlnet_forward_wrapper(sample, t, *args, **kwargs):
t = t.to(device=sample.device)
return controlnet_forward(sample, t, *args, **kwargs)

m.controlnet.forward = controlnet_forward_wrapper

return m
if enable_cuda_graph:
m.forward = make_dynamic_graphed_callable(m.forward)

return m


def _modify_model(
@@ -207,12 +156,80 @@ def _modify_model(

if memory_format is not None:
sfast._C._jit_pass_convert_op_input_tensors(
m.graph, "aten::_convolution", indices=[0], memory_format=memory_format
)
m.graph,
'aten::_convolution',
indices=[0],
memory_format=memory_format)

if enable_cnn_optimization:
passes.jit_pass_optimize_cnn(m.graph)

if prefer_lowp_gemm:
passes.jit_pass_prefer_lowp_gemm(m.graph)
passes.jit_pass_fuse_lowp_linear_add(m.graph)


def _ts_compiler(
m,
call_helper,
inputs,
kwarg_inputs,
modify_model_fn=None,
freeze=False,
preserve_parameters=False,
):
with torch.jit.optimized_execution(True):
if freeze and not getattr(m, 'training', False):
# raw freeze causes Tensor reference leak
# because the constant Tensors in the GraphFunction of
# the compilation unit are never freed.
m = jit_utils.better_freeze(
m,
preserve_parameters=preserve_parameters,
)
if modify_model_fn is not None:
modify_model_fn(m)

return m


def _build_lazy_trace(config):
modify_model = functools.partial(
_modify_model,
enable_cnn_optimization=config.enable_cnn_optimization,
prefer_lowp_gemm=config.prefer_lowp_gemm,
enable_triton=config.enable_triton,
memory_format=config.memory_format,
)

ts_compiler = functools.partial(
_ts_compiler,
freeze=config.enable_jit_freeze,
preserve_parameters=config.preserve_parameters,
modify_model_fn=modify_model,
)

lazy_trace_ = functools.partial(
lazy_trace,
ts_compiler=ts_compiler,
check_trace=False,
strict=False,
)

return lazy_trace_


def _enable_xformers(m):
from xformers import ops
from sfast.utils.xformers_attention import (
xformers_memory_efficient_attention, )

ops.memory_efficient_attention = xformers_memory_efficient_attention

if hasattr(m, 'enable_xformers_memory_efficient_attention'):
m.enable_xformers_memory_efficient_attention()
else:
logger.warning(
'enable_xformers_memory_efficient_attention() is not available.'
' If you have enabled xformers by other means, ignore this warning.'
)