Skip to content

Commit

Permalink
fix EmbeddingSpMDM bf16 in/out (#1583)
Browse files Browse the repository at this point in the history
Summary:
There is random failure https://hud.pytorch.org/pytorch/pytorch/commit/05397b12505f4fd1bc98af562e103f4162993c1a and
pytorch/pytorch#94163 is reverted.

The random failure is caused by re-use `lengths`(which is the `offset` in Pytorch Embedding Bag) addr.

For FP32->BF16 convert, we need to add a VEC with 2^15 and right shift 16 to do round-nearest
https://github.com/pytorch/FBGEMM/blob/main/src/FbgemmBfloat16ConvertAvx2.cc#L18-L21.
The first version I used
```
          a->mov(scratchReg2_, 1 << 15);
          a->vpbroadcastd(ones_vreg, scratchReg2_);
```
This will cause random fail but cannot work on AVX2 since `asmjit` do not support it broadcast from `GP`(scratchReg2_) to `VEC`(ones_vreg).
https://github.com/asmjit/asmjit/blob/996deae3273073bf75fbd6ddeac038dff5fdb6eb/src/asmjit/x86/x86emitter.h#L2794-L2796

As it showed on `asmjit` headers, We can broadcast from `mem` to `VEC`. So I re-use
`lengths` (it is the ptr for `offset` from Pytorch EmbeddingBag).
I first `mov` the content or `lenghts` to `scratchReg2_`, and `mov` `2^15` to `lenghts` ptr and broadcast it to VEC. After this, I recover `lenghts` content with `scratchReg2_`.
```
          // Cannot find a broadcast instruction for int from GP to VEC with
          // AVX2. We use lengths address to perform the broadcast and
          // write it back
          auto temp_addr = x86::dword_ptr(lengths, 0);
          a->mov(scratchReg2_, temp_addr);
          a->mov(temp_addr, 1 << 15);
          a->vpbroadcastd(ones_vreg, temp_addr);
          a->mov(temp_addr, scratchReg2_);
```
This temporary usage of `lengths` ptr caused the random failure (related to memory overlap with multithreaded order).
For example, we have threads `t1` and `t2`. After `t1 write to this addr a->mov(temp_addr, 1 << 15)` and before `t1 recovery this addr (a->mov(temp_addr, scratchReg2_))`. This addr is `read by t2`. That will cause a failure.
This may be because `a->mov(temp_addr, 1 << 15);` may not only write 32 (or even 64-bit or `lengths` ptr) since it may randomly fail while both `indices` and `offset` are int64_t.

I found another path to generate `VEC` with `2^15`. This way we will not do unsafe read/write with the given memory address anymore so we can solve the random failure.
```
          a->mov(scratchReg2_, 1 << 15);
          a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
          a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
```

Pull Request resolved: #1583

Reviewed By: brad-mengchi, jiecaoyu, jiawenliu64

Differential Revision: D43112022

Pulled By: jianyuh

fbshipit-source-id: 54616eac9fb0277674de98143fde0491d0e78deb
  • Loading branch information
zhuhaozhe authored and facebook-github-bot committed Feb 10, 2023
1 parent d88187d commit 03b2046
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,9 @@ GenEmbeddingSpMDMLookup<
if (isbf16out) {
--unroll_factor;
ones_vreg = vec_reg_t(unroll_factor);
// Cannot find a broadcast instruction for int from GP/VEC to VEC with
// AVX2. We use lengths address to perform the broadcast and
// write it back
auto temp_addr = x86::dword_ptr(lengths, 0);
a->mov(scratchReg2_, temp_addr);
a->mov(temp_addr, 1 << 15);
a->vpbroadcastd(ones_vreg, temp_addr);
a->mov(temp_addr, scratchReg2_);
a->mov(scratchReg2_, 1 << 15);
a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
}

if (is8bit || is16bit || (remainder && instSet == inst_set_t::avx2)) {
Expand Down

0 comments on commit 03b2046

Please sign in to comment.