Skip to content

Commit

Permalink
Merge pull request #36 from ROCm/stride_fix
Browse files Browse the repository at this point in the history
Fix grad_k/grad_v strides
  • Loading branch information
qianfengz authored Nov 13, 2024
2 parents 7f91bb1 + 44b6def commit b000bb3
Showing 1 changed file with 2 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ efficient_attention_backward_ck(
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
} else {
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
grad_k = at::empty(key.sizes(), key.options());
grad_v = at::empty(value.sizes(), value.options());
}

at::Tensor grad_q_f32;
Expand All @@ -173,9 +173,7 @@ efficient_attention_backward_ck(
TORCH_CHECK(query.sizes() == grad_q.sizes());
TORCH_CHECK(query.strides() == grad_q.strides());
TORCH_CHECK(key.sizes() == grad_k.sizes());
TORCH_CHECK(key.strides() == grad_k.strides());
TORCH_CHECK(value.sizes() == grad_v.sizes());
TORCH_CHECK(value.strides() == grad_v.strides());

const bool bias_requires_grad = bias.has_value() && bias->requires_grad();

Expand Down

0 comments on commit b000bb3

Please sign in to comment.