
FP8 GEMM for the fprop should use fast accumulation #6168

Closed
kaixih opened this issue Oct 9, 2023 · 15 comments

Labels
GPU XLA on GPU NVIDIA-GPU XLA on Nvidia GPU

Comments

@kaixih
Contributor

kaixih commented Oct 9, 2023

In the current implementation of the FP8 GEMM, the CUBLASLT_MATMUL_DESC_FAST_ACCUM attribute (link) is not configured, which means fast accumulation is disabled by default. However, the Transformer Engine recipe recommends enabling fast accumulation during the fprop pass, which can further improve the speed of the FP8 GEMM operation.

Determining whether a gemm/dot op in the HLO graph belongs to fprop or bprop is a non-trivial challenge. In practice, e4m3 is typically used for both operands of the fprop gemm, while a combination of e4m3 and e5m2 is used for the bprop gemm. To address this, we propose the following: if both input data types are e4m3, we set the aforementioned flag so that fast accumulation is used specifically during the fprop pass. If this proposal is accepted, we can prepare a pull request (PR) to implement the change.
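
A minimal sketch of this heuristic, written as hypothetical Python pseudocode rather than XLA's actual C++ rewriter code (the dtype strings are illustrative):

def should_enable_fast_accum(lhs_dtype, rhs_dtype):
  # fprop GEMMs typically see e4m3 on both operands; bprop GEMMs mix e4m3
  # activations/weights with e5m2 gradients, so they keep precise accumulation.
  return lhs_dtype == "f8e4m3fn" and rhs_dtype == "f8e4m3fn"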

cc. @reedwm @philipphack @nluehr @instinct79

@reedwm
Member

reedwm commented Oct 10, 2023

To address this, we propose a solution: if both input data types are e4m3, we will set the aforementioned flag to ensure fast accumulation is utilized specifically during the fprop pass.

I'm not a fan of making the accumulation precision dependent on which of the FP8 types are used for the inputs. If we want the forward pass to use faster, less precise accumulation, this should be directly encoded in the HLO instruction. Also I don't think convolutions necessarily use different FP8 types on the forward vs backward pass, so this would only work for dots.

How about we use the PrecisionConfig field (see the precision_config field of the StableHLO spec for dot_general)? There is one PrecisionConfig per input and currently it only affects the input precisions, and only when the inputs are FP32. In XLA:GPU, TF32 is used for the inputs if the PrecisionConfig is DEFAULT or HIGH, and FP32 is used if it's HIGHEST.

For FP8 inputs, I propose interpreting the PrecisionConfig slightly differently: if all of the inputs' PrecisionConfigs are HIGHEST, accumulate with full precision; otherwise, use cuBLAS's fast accumulation mode (whose exact precision is unfortunately undocumented). This is different from how PrecisionConfig is currently used: I'm proposing that it affect accumulation precision for FP8 gemms, while currently it only affects input precisions. But I think it's fine for the PrecisionConfig concept to refer to more than just input precision. Currently the StableHLO spec does not define exactly what part of the dot the PrecisionConfig affects (see openxla/stablehlo#755).
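
A minimal sketch of this interpretation, as hypothetical Python pseudocode (the real change would live in XLA's cuBLAS GEMM rewriter, not in Python):

def use_fast_accumulation(operand_precisions):
  # Accumulate in full precision only when every FP8 operand requests HIGHEST;
  # any DEFAULT or HIGH operand opts the dot into cuBLASLt fast accumulation.
  return not all(p == "HIGHEST" for p in operand_precisions)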

@kaixih @cheshire @burmako, WDYT about using PrecisionConfig to specify the accumulation precision for FP8 dots?

@philipphack
Contributor

What's the mechanism for setting the PrecisionConfig for a given FP8 dot?

@reedwm
Member

reedwm commented Oct 10, 2023

What's the mechanism for setting the PrecisionConfig for a given FP8 dot?

In JAX, it can be passed to various functions like jnp.dot as the precision argument. In TF, I think this is not currently possible, but it can be added.
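
For reference, a minimal example of passing the precision argument to jnp.dot in JAX today (plain bfloat16 inputs, not FP8-specific):

import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.bfloat16)
y = jnp.ones((128, 128), dtype=jnp.bfloat16)

# DEFAULT allows the fastest math; HIGHEST requests the most precise path.
out = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)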

@wenscarl
Contributor

In addition to plumbing it through jnp.dot, we may also need a wrapper around fp8_dot inspired by Transformer Engine's design (ref here). PrecisionConfig should be OK since it's not utilized in the current design.
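
A minimal sketch of such a wrapper, assuming a quantize/dequantize pattern around a plain dot (the fp8_dot name, scale handling, and dtypes here are hypothetical, not Transformer Engine's actual API):

import jax
import jax.numpy as jnp

def fp8_dot(x, y, x_scale, y_scale):
  # Quantize both operands to e4m3 (the fprop pattern), then dequantize the
  # result with the product of the scales; DEFAULT precision leaves the door
  # open for fast accumulation in the rewritten FP8 GEMM.
  x_q = (x / x_scale).astype(jnp.float8_e4m3fn).astype(x.dtype)
  y_q = (y / y_scale).astype(jnp.float8_e4m3fn).astype(y.dtype)
  out = jnp.dot(x_q, y_q, precision=jax.lax.Precision.DEFAULT)
  return out * (x_scale * y_scale)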

@burmako
Contributor

burmako commented Oct 16, 2023

@reedwm No objections from the StableHLO side!

@philipphack
Contributor

If we use JAX's precision enum, we may have to augment its documentation.

@reedwm
Member

reedwm commented Oct 16, 2023

I talked to @cheshire and he is also OK with using the PrecisionConfig to specify the accumulation type for FP8 matmuls. So it sounds like this is the way to go. @kaixih do you want to implement this or should I?

Once implemented, we can update the JAX documentation.

Getting JAX to specify a separate PrecisionConfig on the backward pass is a bit tricky, but it can be done by calling jax.jvp or jax.vjp from within a JAX custom JVP/VJP. For example, the dot_precise_grad function below runs the backward pass with higher precision:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def dot_precise_grad(x, y):
  # Forward pass: DEFAULT precision, which allows fast FP8 accumulation.
  return jnp.dot(x, y, precision=jax.lax.Precision.DEFAULT)

@dot_precise_grad.defjvp
def dot_precise_grad_jvp(primals, tangents):
  def dot_precise(x, y):
    # Gradients flow through this dot, which requests HIGHEST precision.
    return jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

  _, jvp = jax.jvp(dot_precise, primals, tangents)
  # Recompute the primal output with DEFAULT precision for the forward pass.
  out = jnp.dot(*primals, precision=jax.lax.Precision.DEFAULT)
  return out, jvp
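
A quick usage check (hypothetical shapes; reverse-mode differentiation still works because JAX transposes the custom JVP):

x = jnp.ones((64, 64))
y = jnp.ones((64, 64))
grads = jax.grad(lambda a, b: dot_precise_grad(a, b).sum(), argnums=(0, 1))(x, y)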

TF currently doesn't support setting the PrecisionConfig on a per-op basis. We should probably add a way to do this once we start training models with FP8 in TF, at which point a similar approach can be used there.

@kaixih
Contributor Author

kaixih commented Oct 16, 2023

Yes, it looks like this is similar to what @wenscarl has just drafted here.

@kaixih
Contributor Author

kaixih commented Oct 16, 2023

@reedwm It appears that our definitions of DEFAULT/HIGHEST differ from yours, as indicated in this link. In our context, DEFAULT signifies fast accumulation, whereas HIGHEST denotes non-fast accumulation. Under this definition, fprop should use DEFAULT and bprop should use HIGHEST. I believe this classification aligns with the current understanding of DEFAULT/HIGHEST:

DEFAULT: Fastest calculation, but least accurate approximation to the original number.
HIGHEST: Slowest calculation, but most accurate approximation to the original number.

@reedwm
Member

reedwm commented Oct 16, 2023

In my example, DEFAULT signifies fast accumulation as well. DEFAULT is used in the forward pass, while HIGHEST is used in the gradients.

@wenscarl
Contributor

In my example, DEFAULT signifies fast accumulation as well. DEFAULT is used in the forward pass, while HIGHEST is used in the gradients.

I tried this commit, and the cublasLt logs show that it works.

@reedwm
Member

reedwm commented Oct 17, 2023

That commit looks good. I see you created and closed a PR #6388. Do you plan on reopening it?

@wenscarl
Contributor

This commit touches the code base drastically. I'm figuring out how to rebase.

@wenscarl
Contributor

Opened PR #6599.

copybara-service bot pushed a commit that referenced this issue Nov 2, 2023
Imported from GitHub PR #6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT. Otherwise it falls back to high-precision accumulation. Issue #6168

This PR is closely related to Flax PR google/flax#3416.
Copybara import of the project:

--
a4140da by shuw <[email protected]>:

Add FP8 fast accumulation support for cublasLt.

--
9684568 by shuw <[email protected]>:

Improve based on review #1

--
e906d76 by shuw <[email protected]>:

Improve based on review #2

Merging this change closes #6599

COPYBARA_INTEGRATE_REVIEW=#6599 from wenscarl:fp8_fast_accumulation e906d76
PiperOrigin-RevId: 578948593
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 2, 2023
Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT. Otherwise it falls back to high-precision accumulation. Issue openxla/xla#6168

This PR is closely related to Flax PR google/flax#3416.
Copybara import of the project:

--
a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <[email protected]>:

Add FP8 fast accumulation support for cublasLt.

--
96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <[email protected]>:

Improve based on review #1

--
e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <[email protected]>:

Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
@penpornk penpornk added NVIDIA-GPU XLA on Nvidia GPU GPU XLA on GPU labels Feb 29, 2024
@reedwm
Member

reedwm commented Feb 29, 2024

Closing as @wenscarl fixed this in #6599.

@reedwm reedwm closed this as completed Feb 29, 2024