Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use flash attention #222

Conversation

Warvito
Copy link
Collaborator

@Warvito Warvito commented Feb 4, 2023

Fixes #210

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
@Warvito Warvito linked an issue Feb 4, 2023 that may be closed by this pull request
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
@Warvito
Copy link
Collaborator Author

Warvito commented Feb 4, 2023

@ericspod After adding the option to use the efficient memory to the model I started to get the following error when running the unit tests:

======================================================================
ERROR: test_script_conditioned_2d_models (tests.test_diffusion_model_unet.TestDiffusionModelUNet2D)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/media/walter/Storage/Projects/GenerativeModels/tests/test_diffusion_model_unet.py", line 402, in test_script_conditioned_2d_models
    test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
  File "/media/walter/Storage/Projects/GenerativeModels/tests/utils.py", line 723, in test_script_save
    convert_to_torchscript(
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/monai/networks/utils.py", line 597, in convert_to_torchscript
    script_module = torch.jit.script(model, **kwargs)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 1286, in script
    return torch.jit._recursive.create_script_module(
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 476, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 538, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 615, in _construct
    init_fn(script_module)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 516, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 894, in compile_unbound_method
    create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 393, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_recursive.py", line 863, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/_script.py", line 1340, in script
    ast = get_jit_def(obj, obj.__name__)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/frontend.py", line 293, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/frontend.py", line 331, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/torch/jit/frontend.py", line 366, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 125
    scale: Optional[float] = None,
    *,
    op: Optional[AttentionOp] = None,
                                ~~~~ <--- HERE
) -> torch.Tensor:
    """Implements the memory-efficient attention mechanism following
'CrossAttention._memory_efficient_attention_xformers' is being compiled since it was called from 'CrossAttention.forward'
  File "/media/walter/Storage/Projects/GenerativeModels/generative/networks/nets/diffusion_model_unet.py", line 164
    
        if self.use_flash_attention:
            x = self._memory_efficient_attention_xformers(query, key, value)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        else:
            x = self._attention(query, key, value)

It looks like a problem with the dependency and torchscript. How would you suggest to deal with this issue?

@ericspod
Copy link
Member

ericspod commented Feb 6, 2023

Where you call memory_efficient_attention you may have to provide all arguments that it requires and do so positionally, so not like keywords like it is and explicitly providing the default values for those you don't want to provide values for here.

…tion-to-the-diffusion-unet

# Conflicts:
#	generative/networks/nets/diffusion_model_unet.py
#	requirements-dev.txt
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
@Warvito Warvito marked this pull request as ready for review February 26, 2023 19:19
@Warvito Warvito added the need reviewer This PR need a reviewer label Feb 27, 2023
@Warvito Warvito merged commit 811ac1e into main Feb 27, 2023
@Warvito Warvito deleted the 210-add-option-to-use-memory-efficient-attention-to-the-diffusion-unet branch March 1, 2023 13:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
need reviewer This PR need a reviewer
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add option to use memory efficient attention to the Diffusion UNet
2 participants