FP8 GEMM for the fprop should use fast accumulation #6168
Comments
I'm not a fan of making the accumulation precision dependent on which of the FP8 types are used for the inputs. If we want the forward pass to use faster, less precise accumulation, this should be encoded directly in the HLO instruction. Also, I don't think convolutions necessarily use different FP8 types on the forward vs. backward pass, so this would only work for dots. How about we use the PrecisionConfig field instead? For FP8 inputs, I propose interpreting the PrecisionConfig slightly differently: if all of the inputs' PrecisionConfigs are HIGHEST, accumulate with full precision; otherwise, use cuBLAS's fast accumulation mode (whose exact precision is unfortunately undocumented). This differs from how PrecisionConfig is currently used: I'm proposing that it affect the accumulation precision for FP8 gemms, whereas today it only affects input precisions. But I think it's fine for the PrecisionConfig concept to refer to more than just input precision. Currently the StableHLO spec does not define exactly what part of the dot the PrecisionConfig affects (see openxla/stablehlo#755). @kaixih @cheshire @burmako, WDYT about using PrecisionConfig to specify the accumulation precision for FP8 dots? |
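A minimal sketch of the rule proposed above (the helper name is illustrative only, not XLA's actual API or implementation): accumulate with full precision only when every FP8 operand's PrecisionConfig is HIGHEST, otherwise use cuBLAS's fast accumulation.

```python
# Illustrative sketch only; not XLA's actual API or implementation.
def use_fast_accumulation(operand_precisions) -> bool:
    """Fast accumulation unless every FP8 operand is marked HIGHEST."""
    return not all(p == "HIGHEST" for p in operand_precisions)

use_fast_accumulation(["DEFAULT", "DEFAULT"])  # True  -> e.g. forward pass
use_fast_accumulation(["HIGHEST", "HIGHEST"])  # False -> e.g. gradient pass
```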
What's the mechanism for setting the PrecisionConfig for a given FP8 dot? |
In JAX, it can be passed to various functions like |
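For context, a minimal sketch of how a per-op precision is passed in JAX today: jnp.dot, jnp.einsum, and jax.lax.dot_general all accept the same precision argument.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 16), dtype=jnp.bfloat16)
y = jnp.ones((16, 32), dtype=jnp.bfloat16)

# Per-op precision is passed directly to the dot-like function.
out_default = jnp.dot(x, y, precision=jax.lax.Precision.DEFAULT)
out_highest = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
```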
In addition to plumbing through |
@reedwm No objections from the StableHLO side! |
If we use JAX's precision enum, we may have to augment its documentation. |
I talked to @cheshire and he is also OK with using the PrecisionConfig to specify the accumulation type for FP8 matmuls, so it sounds like this is the way to go. @kaixih, do you want to implement this or should I? Once implemented, we can update the JAX documentation.

Getting JAX to specify a separate PrecisionConfig on the backward pass is a bit tricky, but it can be done with a custom JVP:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def dot_precise_grad(x, y):
  # Forward pass: default (fast) accumulation.
  return jnp.dot(x, y, precision=jax.lax.Precision.DEFAULT)

@dot_precise_grad.defjvp
def dot_precise_grad_jvp(primals, tangents):
  def dot_precise(x, y):
    # Gradient path: highest-precision accumulation.
    return jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
  _, jvp = jax.jvp(dot_precise, primals, tangents)
  out = jnp.dot(*primals, precision=jax.lax.Precision.DEFAULT)
  return out, jvp
```

TF currently doesn't support setting the PrecisionConfig on a per-op basis. We should probably add a way to do this once we start training FP8 models in TF, at which point a similar approach can be used there. |
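For illustration, a small way to exercise dot_precise_grad from the snippet above (shapes and values are arbitrary; continues from the imports above): the primal output goes through the DEFAULT-precision dot, while the tangent is produced by the HIGHEST-precision rule.

```python
x, y = jnp.ones((4, 8)), jnp.ones((8, 2))
tx, ty = jnp.ones((4, 8)), jnp.ones((8, 2))

# Primal uses Precision.DEFAULT; the tangent path uses Precision.HIGHEST.
out, tangent = jax.jvp(dot_precise_grad, (x, y), (tx, ty))
```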
@reedwm It appears that our definitions of DEFAULT/HIGHEST differ from yours, as indicated in this link. In our context, DEFAULT signifies fast accumulation, whereas HIGHEST denotes non-fast accumulation. According to our definition, fprop should use DEFAULT and bprop should use HIGHEST. I believe this classification aligns with the current understanding of DEFAULT/HIGHEST. |
In my example, DEFAULT signifies fast accumulation as well. DEFAULT is used in the forward pass, while HIGHEST is used in the gradients. |
Tried this commit; the cublasLt logs show it works. |
That commit looks good. I see you created and closed a PR #6388. Do you plan on reopening it? |
This commit touches the code base quite drastically; I'm figuring out how to rebase. |
Imported from GitHub PR openxla/xla#6599

FP8 cublasLt matmul uses fast accumulation when both operands' precisions are DEFAULT; otherwise it falls back to high-precision accumulation. Issue: openxla/xla#6168. This PR is closely related to Flax PR [3416](google/flax#3416).

Copybara import of the project:

-- a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <[email protected]>: Add FP8 fast accumulation support for cublasLt.
-- 96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <[email protected]>: Improve based on review #1
-- e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <[email protected]>: Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
In the current implementation of FP8 GEMM, the CUBLASLT_MATMUL_DESC_FAST_ACCUM setting is not configured, so fast accumulation is disabled by default. However, the Transformer Engine recipe recommends enabling fast accumulation during the fprop pass, which can further speed up the FP8 GEMM operation.

Determining whether a gemm/dot op in the HLO graph belongs to fprop or bprop is non-trivial. In practice, e4m3 is typically used for the fprop GEMM, while a combination of e4m3 and e5m2 is used for the bprop GEMM. We therefore propose the following: if both input data types are e4m3, set the aforementioned flag so that fast accumulation is used specifically during the fprop pass. If this proposal is accepted, we can prepare a PR to implement the change.
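A sketch of the dtype-based heuristic described above, assuming a JAX version that exposes the float8 dtypes (jnp.float8_e4m3fn, jnp.float8_e5m2); the helper is hypothetical and purely for illustration, while the real change would set CUBLASLT_MATMUL_DESC_FAST_ACCUM inside XLA's cublasLt path.

```python
import jax.numpy as jnp

# Hypothetical helper: enable fast accumulation only for the e4m3 x e4m3
# pattern, which in practice corresponds to the fprop GEMM.
def should_enable_fast_accum(lhs_dtype, rhs_dtype) -> bool:
    return lhs_dtype == jnp.float8_e4m3fn and rhs_dtype == jnp.float8_e4m3fn

should_enable_fast_accum(jnp.float8_e4m3fn, jnp.float8_e4m3fn)  # True  (fprop)
should_enable_fast_accum(jnp.float8_e4m3fn, jnp.float8_e5m2)    # False (bprop)
```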
cc. @reedwm @philipphack @nluehr @instinct79