fix EmbeddingSpMDM bf16 in/out (#1583)
Summary: There was a random failure (https://hud.pytorch.org/pytorch/pytorch/commit/05397b12505f4fd1bc98af562e103f4162993c1a), and pytorch/pytorch#94163 was reverted. The random failure was caused by reusing the `lengths` address (`lengths` is the `offset` in PyTorch EmbeddingBag) as temporary storage.

For the FP32->BF16 conversion, we need to add a `VEC` holding `2^15` and right-shift by 16 to get round-to-nearest: https://github.com/pytorch/FBGEMM/blob/main/src/FbgemmBfloat16ConvertAvx2.cc#L18-L21. The first version I used was:

```
a->mov(scratchReg2_, 1 << 15);
a->vpbroadcastd(ones_vreg, scratchReg2_);
```

This cannot work on AVX2, because `asmjit` does not support broadcasting from a `GP` register (`scratchReg2_`) to a `VEC` register (`ones_vreg`): https://github.com/asmjit/asmjit/blob/996deae3273073bf75fbd6ddeac038dff5fdb6eb/src/asmjit/x86/x86emitter.h#L2794-L2796

As the `asmjit` headers show, we can broadcast from `mem` to `VEC`. So I reused `lengths` (the pointer to `offset` from PyTorch EmbeddingBag). I first `mov` the content of `lengths` to `scratchReg2_`, then `mov` `2^15` to the `lengths` pointer and broadcast it to `VEC`. After this, I restore the `lengths` content from `scratchReg2_`.

```
// Cannot find a broadcast instruction for int from GP to VEC with
// AVX2. We use lengths address to perform the broadcast and
// write it back
auto temp_addr = x86::dword_ptr(lengths, 0);
a->mov(scratchReg2_, temp_addr);
a->mov(temp_addr, 1 << 15);
a->vpbroadcastd(ones_vreg, temp_addr);
a->mov(temp_addr, scratchReg2_);
```

This temporary use of the `lengths` pointer caused the random failure (a memory overlap that depends on multithreaded ordering). For example, take threads `t1` and `t2`. After `t1` writes to this address (`a->mov(temp_addr, 1 << 15)`) and before `t1` restores it (`a->mov(temp_addr, scratchReg2_)`), `t2` reads the address and sees `2^15` instead of the real value. That causes a failure. It may also be that `a->mov(temp_addr, 1 << 15);` does not write only 32 bits (it may write 64 bits of the `lengths` pointer), since the failure also occurs randomly when both `indices` and `offset` are int64_t.
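For reference, the round-to-nearest step described above (add `2^15`, then right-shift by 16) can be sketched in scalar C++. This is only an illustration, not the FBGEMM code: `float_to_bf16_round` and `bf16_to_float` are hypothetical helper names, and special values such as NaN are not treated specially here.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper (not part of FBGEMM): convert FP32 to BF16 by adding
// 2^15 to the raw bits and keeping the upper 16 bits. The add rounds the
// discarded lower half to nearest (ties round up in magnitude).
// Assumption: NaN inputs are not handled specially in this sketch.
static inline std::uint16_t float_to_bf16_round(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // reinterpret float as raw bits
  bits += 1u << 15;                      // rounding bias, same constant as ones_vreg
  return static_cast<std::uint16_t>(bits >> 16);
}

// Hypothetical helper: widen BF16 back to FP32 by placing the 16 bits
// in the upper half of a 32-bit word.
static inline float bf16_to_float(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```

The JIT-generated kernel does the same add and shift per vector lane, which is why it needs a vector register with `2^15` in every 32-bit lane.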
I found another path to generate a `VEC` with `2^15`. This way we no longer do unsafe reads/writes on the given memory address, which resolves the random failure.

```
a->mov(scratchReg2_, 1 << 15);
a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
```

Pull Request resolved: #1583

Reviewed By: brad-mengchi, jiecaoyu, jiawenliu64

Differential Revision: D43112022

Pulled By: jianyuh

fbshipit-source-id: 54616eac9fb0277674de98143fde0491d0e78deb