
fix rope precision bug #1016

Closed
wants to merge 1 commit into from

Conversation

fecet

@fecet fecet commented Aug 23, 2023

When we use `model.bfloat16().cuda()`, `inv_freq` is converted to bfloat16 even though we defined it with `dtype=float32`. In bfloat16, positions are not represented correctly.

For example, bfloat16 cannot distinguish between `256.0` and `257.0`, so when we create the position embedding, we force-convert `inv_freq` and `t` to fp32.

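A minimal sketch of the kind of change described above, assuming a typical rotary-embedding layout (the class and variable names here are illustrative, not the exact code in this repo): compute `t` and the frequencies in fp32 even when the module itself has been cast to bf16.

import torch

class RotaryEmbeddingSketch(torch.nn.Module):  # illustrative name
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # may later be cast to bf16 by .bfloat16()

    def forward(self, seq_len, device):
        # Force fp32 here: in bf16, 257.0 rounds to 256.0, so large positions collide.
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

# bf16 rounding check: both lines print 256.0
print(torch.tensor(257.0).bfloat16().item())
print(torch.tensor(256.0).bfloat16().item())
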
@fecet fecet requested a review from a team as a code owner August 23, 2023 04:23
@CLAassistant

CLAassistant commented Aug 23, 2023

CLA assistant check
All committers have signed the CLA.

@StellaAthena
Member

We are aware of this issue and have been discussing it in the Discord server. However, our current understanding is that, because of the way DeepSpeed works, this doesn't actually accomplish what is desired. You can see the discussion here. If you think our assessment is incorrect, we would appreciate you providing screenshots & tests demonstrating this.

@fecet
Author

fecet commented Aug 24, 2023

Hi,

I've conducted tests and here are my findings:

import torch
import torch.nn as nn
# LlamaPreTrainedModel / LlamaModel come from megatron.modeling_llama_neox (imported in the test code below)

class LlamaForRM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.llama = LlamaModel(config)
        self.value_head = nn.Linear(self.config.hidden_size, 1)

        self.llama.kv_enabled(False)
        self.llama.fmha_enabled(False)

        # Initialize weights and apply final processing
        self.post_init()

        self.bce_loss = nn.BCEWithLogitsLoss()

        # Three fp32 test tensors: two plain attributes, plus one registered as a buffer
        # that shares storage with bf16_test_tensor2.
        # bf16_test_tensor = torch.arange(250, 260).float().cuda()
        self.bf16_test_tensor2 = torch.arange(250, 260).float().cuda()
        self.bf16_test_tensor3 = torch.arange(250, 260).float().cuda()
        self.register_buffer("bf16_test_tensor", self.bf16_test_tensor2)
        self.bf16_test_tensor.data = self.bf16_test_tensor2.data

    def test_bf16(self):
        # Temporary fp32 tensor created inside the method.
        a = torch.arange(250, 260).float().cuda().to(dtype=torch.float32)
        print(
            f"bf16 test before convert: bf16_test_tensor:{self.bf16_test_tensor.dtype}  "
            f"a:{a.dtype}  bftest2:{self.bf16_test_tensor2.dtype}  "
            f"bftest3:{self.bf16_test_tensor3.dtype}  bftest_buffer:{self.bf16_test_tensor.dtype}"
        )
        self.bf16_test_tensor = self.bf16_test_tensor.float()
        c = self.bf16_test_tensor.float() + a.float()
        c2 = self.bf16_test_tensor2.float() + a.float()
        c3 = self.bf16_test_tensor3.float() + a.float()
        self.bf16_test_tensor2 *= 2
        print(
            f"bf16 test after convert: bf16_test_tensor:{self.bf16_test_tensor.dtype}  "
            f"a:{a.dtype}  bftest2:{self.bf16_test_tensor2.dtype}  "
            f"c:{c.dtype}  c2:{c2.dtype}  c3:{c3.dtype}"
        )
        ret = {
            "a": a, "c": c, "c2": c2, "c3": c3,
            "bf16_test_tensor": self.bf16_test_tensor,
            "bf16_test_tensor2": self.bf16_test_tensor2,
        }
        return ret

with the following test code:

if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
    from megatron.modeling_llama_neox import LlamaConfig

    config = LlamaConfig.from_pretrained(model_path)
    config.torch_dtype = "bfloat16"
    model = (
        LlamaForRM.from_pretrained(model_path, config=config).bfloat16().cuda()
    )  # use bfloat16
...

optimizer, scheduler = setup_optimizer(model, train_config)
model.cpu()
engine = deepspeed.initialize(
    args=None,
    model=model,
    optimizer=optimizer,
    config=deepspeed_config,
    lr_scheduler=scheduler,
    dist_init_required=False,
    mpu=mpu,
)[0]


engine.train()

engine.module.test_bf16()

# %%

engine.module.bf16_test_tensor.data = engine.module.bf16_test_tensor2.data
engine.module.bf16_test_tensor

# %%

engine.module.test_bf16()

where `# %%` marks a Jupyter cell split.

From the test, it's evident that the model under the DeepSpeed engine does undergo some automatic bfloat16 precision conversion. However, this conversion primarily affects tensors registered via register_buffer; if we define tensors directly as plain attributes (self.something), we can apparently still operate on them in regular fp32 precision. Here's the detailed observation:

(screenshot: output of the first engine.module.test_bf16() call)

As seen in the screenshot above, the result for c3 looks normal, whereas c suffers a precision loss from the bf16 conversion. Notably, after multiplying bf16_test_tensor2 by 2 inside the function, bf16_test_tensor, i.e., the buffer tensor, remains unchanged. To address this, we explicitly aligned the data pointers of the two tensors in the following cell:

engine.module.bf16_test_tensor.data = engine.module.bf16_test_tensor2.data

Consequently, when we run engine.module.test_bf16() again in the subsequent cell, we obtain the expected results:

(screenshot: output of the second engine.module.test_bf16() call)
(This is an MP4-DP2 Llama-13B model; for these cases, the output of each rank is the same.)
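
The same buffer-vs-attribute split can be reproduced outside DeepSpeed: a plain .bfloat16() call converts parameters and registered buffers but leaves plain tensor attributes untouched. A minimal sketch, independent of the engine setup above:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.arange(250, 260).float())  # registered buffer
        self.attr = torch.arange(250, 260).float()                   # plain attribute

m = Demo().bfloat16()
print(m.buf.dtype)   # torch.bfloat16 -- converted along with the parameters
print(m.attr.dtype)  # torch.float32  -- left untouched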

I suspect that some of these behaviors are related to pointer handling in PyTorch's C++ layer and to the data-type conversion of CUDA tensors (as we know, bf16/fp16 in CUDA originate from splitting a uint32_t).

To summarize:

  • Tensors registered via register_buffer get automatically converted by DeepSpeed.
    Plain class attributes and temporary variables inside functions, such as a in our example, do not undergo this conversion.
  • My PR mainly changes the data type of the temporary variable, which should be enough to prevent the bfloat16 conversion of the position ids (assuming I'm interpreting my tests correctly).
  • For tensors registered as buffers, like inv_freq, if there's a need to operate on them in fp32 (is that necessary?), we could use another attribute for the actual storage. Once the DeepSpeed engine is instantiated, we can then explicitly point the buffer's tensor.data, i.e., the data pointer, at that attribute's storage, as sketched below.
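
A hedged sketch of that last idea, using .bfloat16() in place of the DeepSpeed engine and with illustrative names: the fp32 copy lives in a plain attribute, and the buffer is re-pointed at it after the cast.

import torch

class RotarySketch(torch.nn.Module):
    def __init__(self, dim=64, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self._inv_freq_fp32 = inv_freq                       # plain attribute: never cast
        self.register_buffer("inv_freq", inv_freq.clone())   # gets cast to bf16

m = RotarySketch().bfloat16()
print(m.inv_freq.dtype)  # torch.bfloat16
# Once the model/engine is built, re-point the buffer at the fp32 storage:
m.inv_freq.data = m._inv_freq_fp32.data
print(m.inv_freq.dtype)  # torch.float32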

Hope this clarifies things! Awaiting feedback.

@Quentin-Anthony
Member

This was resolved by #1041
