Skip to content

Commit

Permalink
bump version to 0.0.12.post6
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Nov 29, 2023
1 parent 9ad0b6b commit c24b3e6
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/sfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def new_lru_cache(*args, **kwargs):

# This line will be programatically read/write by setup.py.
# Leave them at the bottom of this file and don't touch them.
__version__ = "0.0.12.post5"
__version__ = "0.0.12.post6"
5 changes: 4 additions & 1 deletion src/sfast/compilers/stable_diffusion_pipeline_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ def compile_unet(m, config):
m.to(memory_format=config.memory_format)

if config.enable_jit:
lazy_trace_ = _build_lazy_trace(config, enable_triton_reshape=True)
lazy_trace_ = _build_lazy_trace(
config,
enable_triton_reshape=True,
)
m.forward = lazy_trace_(m.forward)

if enable_cuda_graph:
Expand Down
6 changes: 3 additions & 3 deletions src/sfast/triton/ops/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@eval('''triton.heuristics({
'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] // 32)),
})''')
@triton.jit(do_not_specialize=[4])
@triton.jit
def copy_2d_kernel(
output_ptr,
input_ptr,
Expand Down Expand Up @@ -61,7 +61,7 @@ def copy_2d_kernel(
@eval('''triton.heuristics({
'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] * kwargs['BLOCK_N'] // 32)),
})''')
@triton.jit(do_not_specialize=[5])
@triton.jit
def copy_3d_kernel(
output_ptr,
input_ptr,
Expand Down Expand Up @@ -123,7 +123,7 @@ def copy_3d_kernel(
@eval('''triton.heuristics({
'num_warps': lambda kwargs: max(1, min(16, kwargs['BLOCK_M'] * kwargs['BLOCK_N'] * kwargs['BLOCK_K'] // 32)),
})''')
@triton.jit(do_not_specialize=[6])
@triton.jit
def copy_4d_kernel(
output_ptr,
input_ptr,
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.12.post5
0.0.12.post6

0 comments on commit c24b3e6

Please sign in to comment.