fix dynamic amax history
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman committed Mar 23, 2024
1 parent f3c377f commit bb5b4d6
Showing 2 changed files with 20 additions and 24 deletions.
transformer_engine/pytorch/fp8.py (3 additions, 1 deletion)
@@ -200,6 +200,7 @@ def add_fp8_tensors_to_global_buffer(
         if index_in_buffer in fp8_meta:
             return
 
+        fp8_meta[index_in_buffer] = []
         for forward in (True, False):
             # This algorithm creates a two-way map with `autocast_to_fp8_params` and
             # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
@@ -233,7 +234,8 @@ def add_fp8_tensors_to_global_buffer(
                 fp8_meta[fp8_meta_tensor_key].amax_history)
             cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale)
             cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv)
-            fp8_meta[index_in_buffer] = (len(cls.global_amax_buffer[key]) - 1, key)
+            fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1)
+            fp8_meta[index_in_buffer].append(key)
 
     @classmethod
     def is_fp8_enabled(cls) -> bool:
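
For reference, the fix is easiest to see in isolation. Before this commit, `fp8_meta[index_in_buffer]` was reassigned to a fresh `(position, key)` tuple on every pass of the `for forward in (True, False)` loop, so only the backward pass's entry survived. After the fix, a flat list accumulates both passes as `[fwd_pos, fwd_key, bwd_pos, bwd_key]`, which is exactly what `adjust_amax_history_length` unpacks below. A minimal sketch of the bookkeeping change, using hypothetical stand-in names (`buffer_info`, `forward_0`, `backward_0`) rather than the real TE state:

# Hypothetical stand-ins, not the real TransformerEngine state.
global_amax_buffer = {"forward_0": [], "backward_0": []}
fp8_meta = {}

# Old behavior: the tuple is overwritten on the second (backward) pass,
# so the forward position/key pair is lost.
for forward in (True, False):
    key = "forward_0" if forward else "backward_0"
    global_amax_buffer[key].append(0.0)  # placeholder amax entry
    fp8_meta["buffer_info"] = (len(global_amax_buffer[key]) - 1, key)
print(fp8_meta["buffer_info"])  # (0, 'backward_0') -- forward info is gone

# New behavior: a flat list accumulates both passes.
fp8_meta["buffer_info"] = []
for forward in (True, False):
    key = "forward_0" if forward else "backward_0"
    fp8_meta["buffer_info"].append(len(global_amax_buffer[key]) - 1)
    fp8_meta["buffer_info"].append(key)
print(fp8_meta["buffer_info"])  # [0, 'forward_0', 0, 'backward_0']
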
transformer_engine/pytorch/module/base.py (17 additions, 23 deletions)
@@ -227,48 +227,42 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
         if fwd is None:
             fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
-            fwd_bwd_keys = ("forward", "backward")
         else:
             fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)
-            fwd_bwd_keys = ("forward" if fwd else "backward",)
 
-        for key, fwd_bwd_key in zip(fp8_meta_tensor_keys, fwd_bwd_keys):
-            curr_len = self.fp8_meta[key].amax_history.shape[0]
+        for meta_key in fp8_meta_tensor_keys:
+            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
             if length == curr_len:
                 continue
             if length < curr_len:
-                self.fp8_meta[key].amax_history = self.fp8_meta[key].amax_history[: length].clone()
+                self.fp8_meta[meta_key].amax_history = (
+                    self.fp8_meta[meta_key].amax_history[: length].clone())
             elif length > curr_len:
                 extra_rows = length - curr_len
-                self.fp8_meta[key].amax_history = F.pad(
-                    self.fp8_meta[key].amax_history, pad=(0, 0, 0, extra_rows)
+                self.fp8_meta[meta_key].amax_history = F.pad(
+                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                 )
 
             # Update the global buffers with new amax and history pointers.
             if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
-                index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()]
-                buffer_key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix
-                if buffer_key in FP8GlobalStateManager.global_amax_buffer:
-                    assert (
-                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
-                    ), "TE internal error during amax history change."
-                    FP8GlobalStateManager.global_amax_buffer[buffer_key][index] = (
-                        self.fp8_meta[key].amax_history[0])
-                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][index] = (
-                        self.fp8_meta[key].amax_history)
+                fwd_pos, fwd_key, bwd_pos, bwd_key = (
+                    self.fp8_meta[FP8GlobalStateManager.get_buffer_info()])
+                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
+                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
+                        assert (
+                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
+                        ), "TE internal error during amax history change."
+                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
+                            self.fp8_meta[meta_key].amax_history[0])
+                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
+                            self.fp8_meta[meta_key].amax_history)
 
     def set_meta_tensor(self, fwd: bool) -> None:
         """Init scales and amaxes for fwd | bwd."""
         fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
 
         if self.fp8_meta_tensors_initialized:
             # Handle a changed amax history size. When loading a checkpoint
             # while using CUDA graphs, we simply disallow changing the
             # amax_history size, since that involves moving to a fresh memory
             # location, which would put the global buffer memory and the local
             # module's FP8 tensor pointers out of sync.
             # TODO(ksivaman): catch this case and exit gracefully.
             self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
             return
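
Aside from the new bookkeeping, the resize logic above is ordinary tensor surgery: `amax_history` is a 2-D tensor whose leading dimension is the history length, shrunk by slicing and grown by zero-padding rows at the bottom. A standalone sketch of just that step, with illustrative shapes (the real TE tensor layout may differ):

# Standalone sketch of the resize step; shapes are illustrative.
import torch
import torch.nn.functional as F

def resize_amax_history(amax_history: torch.Tensor, length: int) -> torch.Tensor:
    """Shrink by slicing, grow by appending zero rows, as in the diff."""
    curr_len = amax_history.shape[0]
    if length == curr_len:
        return amax_history
    if length < curr_len:
        return amax_history[:length].clone()  # keep the first `length` rows
    # pad=(0, 0, 0, extra_rows): no padding on the last dim,
    # extra_rows zero-filled rows appended along the first dim.
    extra_rows = length - curr_len
    return F.pad(amax_history, pad=(0, 0, 0, extra_rows))

history = torch.arange(12.0).reshape(4, 3)  # [history_len=4, num_amaxes=3]
assert resize_amax_history(history, 2).shape == (2, 3)
assert resize_amax_history(history, 6).shape == (6, 3)

Note that both paths return a tensor backed by fresh storage, which is why the loop in the diff must also re-point the entries in FP8GlobalStateManager's global buffers.
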


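A final note on why the re-pointing matters, and why the comment in set_meta_tensor flags the checkpoint-plus-CUDA-graphs case: both the slice-and-clone() path and F.pad allocate new storage, so any container still holding the old tensor silently keeps stale data. A small demonstration of that aliasing hazard, with a hypothetical `buffer` list standing in for the global amax-history buffer:

import torch
import torch.nn.functional as F

history = torch.zeros(2, 3)
buffer = [history]           # global buffer holds a reference

history = F.pad(history, pad=(0, 0, 0, 2))  # resize -> new storage
history[2] = 1.0                            # write lands only in the new tensor

print(buffer[0].shape)   # torch.Size([2, 3]) -- still the old tensor
print(history.shape)     # torch.Size([4, 3])

buffer[0] = history      # the fix: rebind at the recorded position
print(buffer[0].shape)   # torch.Size([4, 3]) -- back in sync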