fix EmbeddingSpMDM bf16 in/out #1583

Closed
wants to merge 3 commits into from

zhuhaozhe
Contributor

@zhuhaozhe zhuhaozhe commented Feb 7, 2023

A random failure appeared in https://hud.pytorch.org/pytorch/pytorch/commit/05397b12505f4fd1bc98af562e103f4162993c1a, and
pytorch/pytorch#94163 was reverted.

The random failure is caused by reusing the memory at the lengths address (lengths is the offsets pointer from PyTorch EmbeddingBag) as temporary storage.

For the FP32->BF16 conversion, we need to add a vector (VEC) filled with 2^15 to the FP32 bits and then right-shift by 16 to round to nearest:
https://github.com/pytorch/FBGEMM/blob/main/src/FbgemmBfloat16ConvertAvx2.cc#L18-L21.
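
To make the rounding step concrete, here is a minimal scalar sketch of the same arithmetic for a single value. The function name float_to_bf16_round is hypothetical and is not part of FBGEMM; the sketch also assumes finite inputs and does no NaN handling.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical scalar helper (not an FBGEMM function): converts one FP32
// value to BF16 bits by adding 2^15 and shifting right by 16.
inline uint16_t float_to_bf16_round(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));      // reinterpret the float as raw bits
  bits += 1u << 15;                          // add half of the dropped 16-bit range
  return static_cast<uint16_t>(bits >> 16);  // keep the upper 16 bits
}
```

The vectorized kernel does the same thing per 32-bit lane, which is why it needs a vector register with 2^15 in every lane.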
The first version I used was:

          a->mov(scratchReg2_, 1 << 15);
          a->vpbroadcastd(ones_vreg, scratchReg2_);

This does not cause the random failure, but it cannot work on AVX2 since asmjit does not support broadcasting from a GP register (scratchReg2_) to a VEC register (ones_vreg) there.
https://github.com/asmjit/asmjit/blob/996deae3273073bf75fbd6ddeac038dff5fdb6eb/src/asmjit/x86/x86emitter.h#L2794-L2796

As the asmjit headers show, we can broadcast from memory to VEC. So I reused
lengths (the pointer to the offsets from PyTorch EmbeddingBag).
I first move the content of lengths into scratchReg2_, then write 2^15 to the lengths pointer and broadcast it to VEC. After this, I restore the lengths content from scratchReg2_.

          // Cannot find a broadcast instruction for int from GP to VEC with
          // AVX2. We use lengths address to perform the broadcast and
          // write it back
          auto temp_addr = x86::dword_ptr(lengths, 0);
          a->mov(scratchReg2_, temp_addr);
          a->mov(temp_addr, 1 << 15);
          a->vpbroadcastd(ones_vreg, temp_addr);
          a->mov(temp_addr, scratchReg2_);

This temporary use of the lengths pointer caused the random failure (a data race on shared memory under multithreading).
For example, take threads t1 and t2. If t2 reads this address after t1 writes to it (a->mov(temp_addr, 1 << 15)) and before t1 restores it (a->mov(temp_addr, scratchReg2_)), t2 sees 2^15 instead of the real offset, and that causes a failure.
It may also be that a->mov(temp_addr, 1 << 15) writes more than 32 bits (or even more than 64 bits) at the lengths pointer, since the random failure still occurs when both indices and offsets are int64_t.
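
The t1/t2 interleaving can be replayed deterministically in a single thread. This is only a sketch of the ordering, not the JIT kernel: simulate_scratch_reuse and the Interleaving struct are hypothetical names, and the "t2" read is simulated by an ordinary load placed inside the window.

```cpp
#include <cstdint>

// Hypothetical result type: what t2 observed and what memory holds afterwards.
struct Interleaving {
  int32_t seen_by_t2;
  int32_t final_value;
};

// Replays the save / overwrite / restore sequence on lengths[0], with a
// simulated read by a second thread inside the unsafe window.
inline Interleaving simulate_scratch_reuse(int32_t* lengths) {
  int32_t saved = lengths[0];  // t1: a->mov(scratchReg2_, temp_addr)
  lengths[0] = 1 << 15;        // t1: a->mov(temp_addr, 1 << 15)
  int32_t seen = lengths[0];   // t2: reads the shared offset during the window
  lengths[0] = saved;          // t1: a->mov(temp_addr, scratchReg2_)
  return {seen, lengths[0]};
}
```

With a buffer whose first offset is 0, the simulated t2 read returns 32768 even though the buffer looks untouched once t1 has restored it, which matches why the failure is intermittent and hard to reproduce.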

I found another way to generate a VEC filled with 2^15. It no longer does unsafe reads/writes on the given memory address, so it resolves the random failure.

          a->mov(scratchReg2_, 1 << 15);
          a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
          a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
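
For readers more familiar with intrinsics than asmjit, the same register-only data movement can be sketched as below. The helper name broadcast_via_xmm is hypothetical, the intrinsics are my assumption of the closest equivalents (the compiler is free to pick different instructions), and the non-AVX2 branch is only a portable model so the snippet builds without AVX2 enabled.

```cpp
#include <array>
#include <cstdint>
#if defined(__AVX2__)
#include <immintrin.h>
#endif

// Hypothetical helper: returns eight 32-bit lanes that all hold `value`,
// built without touching any caller-owned memory.
inline std::array<int32_t, 8> broadcast_via_xmm(int32_t value) {
  std::array<int32_t, 8> out{};
#if defined(__AVX2__)
  // Put the GP value into lane 0 of an XMM register (counterpart of vpinsrd).
  __m128i low = _mm_insert_epi32(_mm_setzero_si128(), value, 0);
  // Broadcast lane 0 to all eight lanes (counterpart of vpbroadcastd ymm, xmm).
  __m256i all = _mm256_broadcastd_epi32(low);
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out.data()), all);
#else
  // Portable model of the same result when AVX2 is not enabled at compile time.
  out.fill(value);
#endif
  return out;
}
```

Calling broadcast_via_xmm(1 << 15) gives eight lanes of 32768. Only lane 0 of the XMM register matters for the broadcast, so the other lanes left over after the insert do not affect the result.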

@netlify

netlify bot commented Feb 7, 2023

Deploy Preview for pytorch-fbgemm-docs canceled.

Name Link
🔨 Latest commit 026c33e
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/63e47b29663f770009ee4704

@zhuhaozhe
Contributor Author

Hi, @jianyuh. Could you help review this? pytorch/pytorch#94163 was reverted because of a random failure that was not exposed before.

@facebook-github-bot
Contributor

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhuhaozhe
Contributor Author

Hi, @jianyuh. Sorry to bother you, but is there any chance we can merge pytorch/pytorch#94163 (which depends on this fix)? We hope to catch up with PT 2.0 for BF16 embedding support.

@jianyuh
Member

jianyuh commented Feb 9, 2023

Hi @zhuhaozhe, sorry, there are some trunk issues on fbgemm (from #1582). It might take one more day to land.

@facebook-github-bot
Contributor

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@jianyuh merged this pull request in 03b2046.
