From c24b3e684f2d31a9c0208d9b4f0307e87510ce76 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Wed, 29 Nov 2023 22:38:31 +0800 Subject: [PATCH] bump version to 0.0.12.post6 --- src/sfast/__init__.py | 2 +- src/sfast/compilers/stable_diffusion_pipeline_compiler.py | 5 ++++- src/sfast/triton/ops/copy.py | 6 +++--- version.txt | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/sfast/__init__.py b/src/sfast/__init__.py index 3a0ed99..21a85c2 100644 --- a/src/sfast/__init__.py +++ b/src/sfast/__init__.py @@ -32,4 +32,4 @@ def new_lru_cache(*args, **kwargs): # This line will be programatically read/write by setup.py. # Leave them at the bottom of this file and don't touch them. -__version__ = "0.0.12.post5" +__version__ = "0.0.12.post6" diff --git a/src/sfast/compilers/stable_diffusion_pipeline_compiler.py b/src/sfast/compilers/stable_diffusion_pipeline_compiler.py index ec8741e..87865eb 100644 --- a/src/sfast/compilers/stable_diffusion_pipeline_compiler.py +++ b/src/sfast/compilers/stable_diffusion_pipeline_compiler.py @@ -132,7 +132,10 @@ def compile_unet(m, config): m.to(memory_format=config.memory_format) if config.enable_jit: - lazy_trace_ = _build_lazy_trace(config, enable_triton_reshape=True) + lazy_trace_ = _build_lazy_trace( + config, + enable_triton_reshape=True, + ) m.forward = lazy_trace_(m.forward) if enable_cuda_graph: diff --git a/src/sfast/triton/ops/copy.py b/src/sfast/triton/ops/copy.py index f761b30..66a51b0 100644 --- a/src/sfast/triton/ops/copy.py +++ b/src/sfast/triton/ops/copy.py @@ -15,7 +15,7 @@ @eval('''triton.heuristics({ 'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] // 32)), })''') -@triton.jit(do_not_specialize=[4]) +@triton.jit def copy_2d_kernel( output_ptr, input_ptr, @@ -61,7 +61,7 @@ def copy_2d_kernel( @eval('''triton.heuristics({ 'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] * kwargs['BLOCK_N'] // 32)), })''') -@triton.jit(do_not_specialize=[5]) +@triton.jit def copy_3d_kernel( output_ptr, input_ptr, @@ -123,7 +123,7 @@ def copy_3d_kernel( @eval('''triton.heuristics({ 'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] * kwargs['BLOCK_N'] * kwargs['BLOCK_K'] // 32)), })''') -@triton.jit(do_not_specialize=[6]) +@triton.jit def copy_4d_kernel( output_ptr, input_ptr, diff --git a/version.txt b/version.txt index e9a4474..ba1ad0b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.0.12.post5 +0.0.12.post6