Add torch compile support for ck attention op #1085

Merged: 1 commit, Aug 26, 2024
tests/test_mem_eff_attention.py (1 change: 0 additions & 1 deletion)
@@ -3312,7 +3312,6 @@ def _merge_attentions_ref(attn_split, lse_split):


@sm80_or_better_only
@skip_if_rocm # rocm doesn't support backward yet
@pytest.mark.parametrize(
"bias_t",
[None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask],
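For context, here is a hedged sketch (not taken from the test file) of how the bias types parametrized above are passed to memory_efficient_attention. The BlockDiagonalMask case corresponds to the variable-length "1MHK" layout that the meta kernels below also handle; all shapes and sequence lengths are illustrative placeholders.

```python
import torch
import xformers.ops.fmha as fmha
from xformers.ops import memory_efficient_attention

device, dtype = "cuda", torch.float16

# Fixed-length batch (BMHK layout: [batch, seqlen, num_heads, head_dim])
# with a causal mask supplied via LowerTriangularMask.
q = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
k = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
v = torch.randn(2, 128, 8, 64, device=device, dtype=dtype)
out = memory_efficient_attention(q, k, v, attn_bias=fmha.attn_bias.LowerTriangularMask())

# Variable-length sequences concatenated into a single batch entry (1MHK),
# described by a BlockDiagonalMask built from the per-sequence lengths.
seqlens = [48, 80]
qc = torch.randn(1, sum(seqlens), 8, 64, device=device, dtype=dtype)
kc = torch.randn(1, sum(seqlens), 8, 64, device=device, dtype=dtype)
vc = torch.randn(1, sum(seqlens), 8, 64, device=device, dtype=dtype)
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
out_varlen = memory_efficient_attention(qc, kc, vc, attn_bias=bias)
```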
@@ -547,10 +547,96 @@ efficient_attention_backward_ck(
return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
efficient_attention_backward_ck_meta(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const c10::optional<at::Tensor>& bias, // additive attention bias
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
// position of the first query token for batch $b
const c10::optional<at::Tensor>& seqstart_q,
// (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
// position of the first key token for batch $b
const c10::optional<at::Tensor>& seqstart_k,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_q_,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_k_,
const c10::optional<at::Tensor>& seqlen_k,
const at::Tensor& logsumexp,
const at::Tensor& out,
double dropout_p, // dropout probability
int64_t rng_seed, // seed using for generating random numbers for dropout
int64_t rng_offset, // offset into random number sequence
int64_t custom_mask_type,
const c10::optional<double> scale,
const c10::optional<int64_t> window_size) {
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t Hq = query.size(2);
int64_t Hkv = key.size(2);
int64_t K = query.size(3);
int64_t Kv = value.size(3);

auto opts = query.options();

at::Tensor grad_q, grad_k, grad_v, grad_bias;

if (query.size(1) == key.size(1) && query.size(3) == value.size(3) &&
query.size(2) == key.size(2) &&
query.storage().is_alias_of(key.storage()) &&
query.storage().is_alias_of(value.storage())) {
// Create one big contiguous chunk for grad_q, grad_k, grad_v
// This is because q, k and v usually come from a single
// output of a linear layer that is chunked.
// Creating the gradients with the right layout saves us
// a `torch.cat` call in the backward pass
at::Tensor chunk = at::empty({B, M, 3, Hq, K}, opts);
grad_q = chunk.select(2, 0);
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
} else if (
key.size(3) == value.size(3) &&
key.storage().is_alias_of(value.storage())) {
// Create one big contiguous chunk for grad_k, grad_v
// This is because k and v usually come from a single
// output of a linear layer that is chunked.
// Creating the gradients with the right layout saves us
// a `torch.cat` call in the backward pass
at::Tensor chunk = at::empty({B, N, 2, Hkv, Kv}, opts);
grad_k = chunk.select(2, 0);
grad_v = chunk.select(2, 1);

grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
} else {
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
}

const bool bias_requires_grad = bias.has_value() && bias->requires_grad();
// Even though grad_bias is an output, CK-FlashAttn requires it to use the
// same data type as bias.
if (bias_requires_grad) {
grad_bias =
at::empty_strided(bias->sizes(), bias->strides(), bias->options());
}
return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
TORCH_FN(efficient_attention_backward_ck));
}

TORCH_LIBRARY_IMPL(xformers, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
TORCH_FN(efficient_attention_backward_ck_meta));
}
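The comments above explain why grad_q, grad_k and grad_v are allocated as views into one contiguous chunk when the inputs alias the same storage. A minimal CPU sketch of that layout (sizes are placeholders, not values from the PR) shows why the resulting views already match a fused QKV projection's chunked output, so the backward pass can skip a torch.cat:

```python
import torch

B, M, Hq, K = 2, 16, 4, 8  # placeholder sizes, illustration only

# Mirrors at::empty({B, M, 3, Hq, K}) followed by select(2, i):
# a single allocation backs all three gradient tensors.
chunk = torch.empty(B, M, 3, Hq, K)
grad_q = chunk.select(2, 0)  # shape [B, M, Hq, K], a view into `chunk`
grad_k = chunk.select(2, 1)
grad_v = chunk.select(2, 2)

# The views interleave along the third dimension, i.e. the same layout a
# chunked QKV linear projection produces, so no concatenation is needed.
assert grad_q.shape == (B, M, Hq, K)
assert grad_q.stride() == (M * 3 * Hq * K, 3 * Hq * K, K, 1)
assert grad_q.untyped_storage().data_ptr() == grad_v.untyped_storage().data_ptr()
```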
@@ -374,10 +374,65 @@ efficient_attention_forward_ck(
return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
}

/*
There are 2 modes for using this function.
(Mode BMHK) With all the heads having the same seqlen
(Mode 1MHK) `batch=1` with all tokens across batches concatenated
*/
std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>
efficient_attention_forward_ck_meta(
const at::Tensor& query, // [b, seqlen, num_heads_q, K]
const at::Tensor& key, // [b, seqlen, num_heads_kv, K]
const at::Tensor& value, // [b, seqlen, num_heads_kv, Kv]
const c10::optional<at::Tensor>& bias, // [b, num_heads_q, seqlen, seqlen]
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
// position of the first query token for batch $b
const c10::optional<at::Tensor>& seqstart_q,
// (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the
// position of the first key token for batch $b
const c10::optional<at::Tensor>& seqstart_k,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_q_,
double dropout_p, // attention matrix dropout probability
bool compute_logsumexp,
int64_t custom_mask_type,
c10::optional<double> scale,
const c10::optional<at::Tensor>& seqlen_k,
const c10::optional<int64_t> window_size) {
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t Hq = query.size(-2);
int64_t Hkv = key.size(-2);
int64_t K = query.size(-1);
int64_t Kv = value.size(-1);
auto opts = query.options();
at::Tensor logsumexp;
at::Tensor out = at::empty({B, M, Hq, Kv}, opts);
int64_t philox_seed = 0;
int64_t philox_offset = 0;
if (!seqstart_q.has_value()) { // input is batched

[Inline review comments on this hunk]

Contributor: I don't think this 'if-else' is necessary. When there is no varlen, the batch size is already 1 right?

Member (Author): Thanks! Will follow-up: currently making this consistent with fbcode.

if (compute_logsumexp) {
logsumexp = at::empty({B, Hq, M}, opts.dtype(at::kFloat));
}
} else {
if (compute_logsumexp) {
logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
}
}
return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
TORCH_FN(efficient_attention_forward_ck));
}

TORCH_LIBRARY_IMPL(xformers, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
TORCH_FN(efficient_attention_forward_ck_meta));
}
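Registering these functions under the Meta dispatch key is what allows torch.compile to trace the CK attention ops with FakeTensors (propagating shapes, strides and dtypes without launching a kernel) instead of failing or graph-breaking. Below is a hedged end-to-end sketch, assuming an xformers build where the CK backend is selected (e.g. on ROCm); none of this code comes from the PR:

```python
import torch
import xformers.ops as xops

# BMHK inputs: [batch, seqlen, num_heads, head_dim].
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

@torch.compile
def attention(q, k, v):
    return xops.memory_efficient_attention(q, k, v)

# Compilation traces output metadata through the Meta kernels registered above;
# the actual forward and backward still run the CK kernels.
out = attention(q, k, v)
out.sum().backward()
```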