fix EmbeddingSpMDM bf16 in/out #1583
Hi @jianyuh, could you help review this? pytorch/pytorch#94163 was reverted because of a random failure that had not been exposed before.
@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hi @jianyuh, sorry to bother you, but is there any chance we can merge pytorch/pytorch#94163 (which depends on this fix)? We would like to catch up with PT 2.0 for BF16 embedding support.
Hi @zhuhaozhe, sorry, there are some trunk issues on FBGEMM (from #1582). Landing might take one more day.
@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
There is a random failure at https://hud.pytorch.org/pytorch/pytorch/commit/05397b12505f4fd1bc98af562e103f4162993c1a, and pytorch/pytorch#94163 was reverted.
The random failure is caused by re-using the `lengths` address (which is the `offset` in PyTorch EmbeddingBag).

For the FP32->BF16 conversion, we need to add a VEC holding 2^15 and then right shift by 16 to do round-to-nearest:
https://github.com/pytorch/FBGEMM/blob/main/src/FbgemmBfloat16ConvertAvx2.cc#L18-L21.
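The rounding step described above can be sketched in scalar form. This is only an illustration of the arithmetic (add 2^15 to the FP32 bit pattern, then keep the upper 16 bits); the names `float_to_bf16_round` and `bits_to_float` are hypothetical helpers, not FBGEMM functions, and NaN handling is deliberately left out.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: reinterpret a 32-bit pattern as an FP32 value.
inline float bits_to_float(uint32_t bits) {
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// Scalar sketch of FP32 -> BF16 with rounding:
// add 2^15 (half of the discarded 16-bit range) to the bit pattern,
// then shift right by 16 to keep the upper half.
// NaN inputs are not handled specially in this sketch.
inline uint16_t float_to_bf16_round(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 1u << 15;
  return static_cast<uint16_t>(bits >> 16);
}
```

In the vectorized kernel the same `1 << 15` value has to be present in every 32-bit lane of a vector register, which is why a broadcast of that constant is needed.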
The first version I used broadcast directly from a GP register (`scratchReg2_`) to a VEC register (`ones_vreg`).
This will not cause the random failure, but it cannot work on AVX2 since `asmjit` does not support broadcast from GP to VEC there:
https://github.com/asmjit/asmjit/blob/996deae3273073bf75fbd6ddeac038dff5fdb6eb/src/asmjit/x86/x86emitter.h#L2794-L2796
As shown in the `asmjit` headers, we can broadcast from mem to VEC, so I re-used `lengths` (the pointer for `offset` from PyTorch EmbeddingBag). I first `mov` the content of `lengths` to `scratchReg2_`, then `mov` `2^15` to the `lengths` pointer and broadcast it to VEC. After this, I restore the `lengths` content from `scratchReg2_`.

This temporary use of the `lengths` pointer caused the random failure (related to memory overlap and execution order across threads). For example, suppose we have threads `t1` and `t2`. After `t1` writes to this address (`a->mov(temp_addr, 1 << 15)`) and before `t1` restores it (`a->mov(temp_addr, scratchReg2_)`), the address is read by `t2`. That will cause a failure.

This may be because `a->mov(temp_addr, 1 << 15);` may not write only 32 bits (or even only 64 bits) at the `lengths` pointer, since it can still randomly fail when both `indices` and `offset` are int64_t.

I found another path to generate a VEC with `2^15`. This way we no longer do unsafe reads/writes on the given memory address, so the random failure is solved.
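
For illustration, here is one register-only way such a constant could be built. It is an assumption for the sake of example and not necessarily the path taken in this PR. The sketch is a portable scalar model of per-lane operations that AVX2 provides (`vpcmpeqd` of a register with itself, then `vpsrld` and `vpslld`); `Vec8u32` and the helper names are hypothetical.

```cpp
#include <array>
#include <cstdint>

// Hypothetical scalar model of a 256-bit register with eight 32-bit lanes.
using Vec8u32 = std::array<uint32_t, 8>;

// Models comparing a register with itself for equality (as vpcmpeqd does):
// every lane becomes all ones, with no memory operand involved.
inline Vec8u32 all_ones() {
  Vec8u32 v{};
  v.fill(0xFFFFFFFFu);
  return v;
}

// Models a per-lane logical right shift (as vpsrld does).
inline Vec8u32 shift_right(Vec8u32 v, unsigned n) {
  for (auto& lane : v) lane >>= n;
  return v;
}

// Models a per-lane logical left shift (as vpslld does).
inline Vec8u32 shift_left(Vec8u32 v, unsigned n) {
  for (auto& lane : v) lane <<= n;
  return v;
}

// all ones -> 1 in each lane -> 2^15 in each lane.
inline Vec8u32 make_rounding_bias() {
  return shift_left(shift_right(all_ones(), 31), 15);
}
```

Because every step works on registers only, nothing is written to the caller's `lengths` buffer, so another thread reading that buffer cannot observe a temporary value.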