Replies: 22 comments 41 replies
-
I believe many hardware vendors will want to allow matmuls to be performed directly in FP8, with hardware support for scaling. This is what enables users to benefit from the FLOPs improvement of FP8 (in contrast to the memory savings). That is, there will be an operation

```python
F = float32  # Or, maybe more likely as discussed above, a float16

def matmul_scaled(X : Tensor[float8], Y : Tensor[float8], scale : float) -> Tensor[F]:
    """
    Returns (X @ Y) * scale
    """
    ...
```

in terms of which a version of the quantized matmul can be written:

```python
def quantized_matmul_fast(quantized_x : Tensor[float8], quantized_y : Tensor[float8], x_scale, y_scale, z_scale):
    # Do the matmul, and divide by the old scale
    z = matmul_scaled(quantized_x, quantized_y, x_scale * y_scale / z_scale)
    # Compute the new scale
    new_z_scale = fn(z)
    # Quantize the matmul output (already scaled by the old scale)
    return cast(z, fp8_e4m3), new_z_scale
```
-
I'm not sure what exactly the formula for the "zero point" is, so it would be useful to add it.
-
The approach proposed here involves what's often known as "fake quantization" around the matmul in HLO that's then pattern matched into a true quantized fp8 matmul. I'd like to propose an alternative that involves a direct fp8 matmul in HLO. JAX code that's 1:1 with the two HLO encodings:

```python
# RFC approach with fake quantization
def matmul_fp8_rfc(x_bf16, y_bf16, x_amax, y_amax, z_amax):
    x_rounded = unscale_to_bf16(scale_to_fp8(x_bf16, x_amax), x_amax)
    y_rounded = unscale_to_bf16(scale_to_fp8(y_bf16, y_amax), y_amax)
    z_bf16 = jnp.dot(x_rounded, y_rounded)
    new_z_amax = amax(z_bf16)
    return z_bf16, new_z_amax

# Alternative approach that directly represents what cuBLASLt is doing
def matmul_fp8_alt(x_bf16, y_bf16, x_amax, y_amax, z_amax):
    x_fp8_scaled = fp8(scale(x_bf16, x_amax))
    y_fp8_scaled = fp8(scale(y_bf16, y_amax))
    z_bf16_scaled_x_y = jnp.dot(x_fp8_scaled, y_fp8_scaled, precision='fp8', preferred_element_type=jnp.bfloat16)
    z_bf16 = unscale(z_bf16_scaled_x_y, x_amax, y_amax)
    new_z_amax = amax(z_bf16)
    return z_bf16, new_z_amax
```

The advantages of the RFC approach in a high-level frontend are:
But the RFC is proposing an implementation for an IR, not a high-level frontend. The XLA GPU backend is not the only consumer of HLO IR, and the more that HLO IR diverges from the operational intent of the user (even if it continues to have about the same numerical results), the more obligatory transformations a backend has to implement before it's operationally correct. A few concrete examples of things that would be easier with the alternative approach:
I was hoping that if the approach in the RFC is implemented, we in JAX would still be able to expose the alternative approach to our users (including library authors), many of whom are likely to value clear operational semantics ("the type you write is the type you get") more than maximal similarity to unquantized code. But I think that would require us to do a pattern match, since the transformation wouldn't be local to the matmul lowering (we need to introduce an unscale, not just a cast to bf16), and we don't really want to/aren't really able to do pattern matches in JAX (hence the desire for jaxpr to stay fairly close to HLO, with only local transformations during lowering). I recognize that the same need for frontend -> IR pattern matching is true in the RFC -> alt direction, although TF does have the tooling to implement nonlocal lowerings to HLO; if TF would like to generate code with the RFC approach then I'm advocating that XLA support matching both patterns. The alternative approach is also already the approach used by JAX library authors to implement int8 quantization-aware training (e.g. AQT), so having to follow the RFC approach for fp8 would be inconsistent with our approach to integer quantization. |
-
As the author of the quantized types many years ago, I definitely want to review closely any use of them which introduces symbols as this RFC suggests -- that was listed as future so I reserved comment. Imo, for a dialect like StableHLO, the only valid use of quantized types is as a higher level "sugaring" for the sake of some of the traditional frontends/tools that reason in those terms: they should be reducible to concrete IR in StableHLO that is correct and performant on mainline platforms (which may require additional ops to express in a hardware aligned way). That is also what the principle we were discussing means: the result of any pattern matching or desugaring should be expressible in StableHLO itself. I would probably go a step further and suggest that it should be expressible without information loss that will necessitate further pattern matching to recover (but that is quite subjective and needs to be evaluated case by case). |
-
Given that future hardware implementations of FP8 are also in mind, what is the plan to extend this?
-
"XLA will initially only support NVIDIA's proposed FP8 types, since their hardware supporting FP8 will be among the first to be released." Graphcore already has FP8 hardware available. It seems very shortsighted not to support the GC/AMD variety too.
-
(I also raised this during the OpenXLA meeting.) We should proactively think about ways to alert the user in case their usage of the contracts introduced in the RFC fails to pattern-match to genuine FP8 calls, perhaps because the user did not strictly follow the steps put forth and/or the pattern-matching heuristics are not encompassing enough.
-
This is a discussion for the framework front-end design, I realize, but it would be nice if we could come up with a clever bookkeeping abstraction that encapsulates the FP8 tensor values along with their respective scales, so that the user doesn't have to keep dragging the two components around separately if they don't want to.
-
Last Tuesday (Nov 17, 2022), I presented this RFC at the XLA community meeting. At the end of my presentation, there was a Q&A section. I'll summarize the questions and answers here.

Q: Is there work in supporting FP8 in frameworks like TensorFlow and JAX?

Q: The FP8 RFC proposes adding a lot of additional scaling ops, like multiply/divide ops. Won't this increase the size of the IR?

Q: The example shown in the presentation has a matmul multiplying an input with itself, which results in a large value, which can overflow. (In the RFC, the example is the first example in the "Scaling" section.) Will this cause problems in practice?

Q: How well supported are StableHLO's quantized types and ops?

Q: If the user writes a pattern to do a scaled matmul but the pattern matching fails, will the user be alerted in any way?
-
Please note that this RFC will be open for review until December 9, 2022. I also updated the intro of the RFC with this information. I'm also adding a comment here because notifications are sent out for comments but not for RFC edits.
-
I think the line
-
For the scaling factor determination, I think NVIDIA's logic is a little more sophisticated than a static hypothesis such as
-
It would still be necessary to insert a flag for selecting the conversion mode, either RNE (default) or SR, from the XLA perspective, or at least from its immediate upper layer. We just don't want to lack this option when dealing with bulk data (tensor) conversions with FP8 as the destination type, in the context of the main FP8 fast path (vs. casual/low-performance/low-volume or non-fusible SR conversions).
-
A technical note: we may want to use F8_MAX (system definable) in place of 448 to be more vendor-friendly and generic. Different float8 implementations may come with different ranges (OpenAI/Triton's, for example).
-
Some additional notes regarding delayed scaling, amax, and some general thoughts. The notes are:
-
I have a minor comment regarding the patterns that will be matched in XLA for the scales, i.e.:
The division can be expressed as multiplication by the reciprocal of the actual scale. Since the actual implementation can vary in TF or JAX, I think it'd be better to match patterns with both multiplication and division for scaling.
-
"When converting to FP8, XLA will use the typical round-to-even behavior as used in other floating-point dtypes. However, in practice, FP8 should saturate on overflow, because the scale might end up being slightly too large." FP8 should saturate on overflow unless the conversion is to E5M2 and loss scaling is enabled. |
-
What's proposed in this RFC follows the Transformer Engine approach, where per-tensor scaling is used and the scaling factor is determined by the amax of the previous iteration. While this is one interesting approach, we think it is important to be inclusive of other approaches developed by other vendors (AMD, Graphcore, etc.) to give the user more options.
-
The RFC is now approved, since the review period ended December 9 and there have been no major objections to the design. The RFC originally stated:
I replaced that paragraph with:
Despite being closed for review, if you have any more questions or feedback, please comment. Although it is too late to make any major changes to the design for the initial FP8 implementation, we are still interested in your comments, as we may evolve the design in the future (such as using StableHLO's quantized ops/types).
-
@reedwm Hi reedwm, sorry for the late reply. Since FP8 SE4M3FUZ has been merged (3b96f8f), I am really curious about your arithmetic assumptions for FP8:
In Graphcore hardware, you can do FP8 multiplication directly to produce an FP16 output.
If you cast the data back to FP32, that means the datatype is simulated, not supported; it is perhaps only useful to reduce model size and I/O burdens. Truly supported FP8 means double the speed in both throughput (larger) and latency (faster).
-
Dr. Sergio Perez, in a recently accepted paper (open-review version: https://openreview.net/pdf?id=nErbvDkucY), has demonstrated that Graphcore-AMD-Qualcomm's FP8 (e.g. SE4M3FUZ) can be effectively used in large-model inference and training without per-channel scaling for weights! You can see that about 1% error is observed in 70B Llama inference and 13B GPT finetuning tasks. Graphcore-AMD-Qualcomm's format actually works well with scaling.
-
RFC: FP8 in XLA
Overview
NVIDIA is introducing support for new 8-bit floating-point formats, collectively referred to as FP8, in their upcoming Hopper GPUs. FP8 results in a 1.2x to 1.5x end to end speedup vs 16-bit training for large language models. According to NVIDIA, there is no degradation in accuracy for most image classification, image detection, GAN, and NLP models. This RFC proposes a design for adding FP8 support to XLA.
Our goal is to have initial FP8 XLA support for Hopper GPUs by the end of 2022.
This RFC has been approved and therefore closed for review on December 9, 2022. You are still free to comment with any questions or thoughts on this design.
/CC @burmako @choucc34 @d0k @hawkinsp @abattery @stellaraccident @nluehr
Summary
Background Summary
Hopper supports two FP8 data types: E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits). Both will be supported in XLA. Other companies are also proposing their own FP8 E4M3 and E5M2 data types, but they differ in minor details such as the NaN encoding. XLA will initially only support NVIDIA's proposed FP8 types, since their hardware supporting FP8 will be among the first to be released.
FP8 has a low dynamic range and is prone to underflow and overflow. Therefore, NVIDIA recommends each FP8 tensor has an associated scale, where the true value of the tensor is the FP8 tensor multiplied by the scale. This type of scaling is a form of symmetric quantization.
The scale is dynamically computed during training. For performance reasons, it is impossible to compute the optimal scale and use it during the same step. Therefore NVIDIA recommends that each step uses the scale from the previous step and computes the scale for the next step.
Design summary
In HLO, MHLO, and StableHLO, two new dtype enum values will be added: `f8E5M2` and `f8E4M3`, corresponding to the NVIDIA dtypes supported in Hopper. This is the only change made to the HLO, MHLO, and StableHLO format.
Scaling will be represented using existing multiply and divide HLO instructions. In general, to run an op such as Dot with FP8 and scaling, the FP8 inputs will be cast to FP16, then multiplied by the input scales. Then the Dot will be run with FP16 inputs and outputs. A Reduce op will calculate the maximum value of the FP16 Dot output, which is used to compute the new scale for the next step. Then, the FP16 outputs are divided by the output scale and cast back to FP8. This whole process will be fused, so we don't actually pay the cost of running the Dot in FP16. (See section "Scaling" for details.)
StableHLO has special quantized types and ops, which could represent FP8 scaling. For now, we choose not to use them, since these types/ops do not yet support dynamic scales and are not yet supported in HLO. We will consider using them in the future. (See section "StableHLO quantization types and ops" for details.)
cuBLAS/cuDNN directly supports scaling for matmuls and convolutions. XLA will use pattern matching to rewrite Dot and Convolution ops with scaling via Multiply/Divide ops into cuBLAS/cuDNN calls. For non-matmul non-convolution ops with scaling, XLA will fuse them. (See section "XLA GPU codegen" for details.)
FP8 convergence and performance will be tested by training a ResNet50 and BERT model in FP8. (See section "Testing plan" for details.)
Background
Hopper supports two FP8 data types: E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits).
E4M3 has more precision but less dynamic range than E5M2. E5M2 is similar to FP16, the only difference being that E5M2 has 8 fewer mantissa bits. This is similar to how bfloat16 is identical to FP32 except that it has 16 fewer mantissa bits. E4M3 is more unusual in that it doesn't support infinities and only has two representations for NaN. For more details on these two formats, see this whitepaper.
NVIDIA, ARM, and Intel are working towards standardizing these two FP8 data types, as described in this blog post. Other companies are proposing slightly different versions of FP8, however. These proposals also have an E4M3 and an E5M2 data type, but differ in details such as support for infinities, NaN, and negative zero. For example, while both the NVIDIA types support negative zero, GraphCore and AMD are proposing an FP8 standard where neither E4M3 nor E5M2 support negative zero. Tesla proposed an FP8 format where neither E4M3 nor E5M2 has Inf or NaN, but both have negative zero.
In XLA, we plan on initially supporting the dtypes proposed by NVIDIA, ARM, and Intel, because NVIDIA Hopper GPUs will likely be very popular and we want XLA to have optimal performance on such hardware in the short term. In the future, we will consider supporting other vendors' FP8 data types.
FP8 in machine learning
During training, NVIDIA found that E4M3 should be used on the forward pass, and E5M2 on the backward pass for models to converge to good quality. The forward pass requires the extra bit of precision, while the backward pass requires the increased dynamic range.
Since these types have very little precision and reduced dynamic range, they are particularly prone to overflow and underflow, especially E4M3. To address this, NVIDIA recommends each tensor have a scale factor, similar to how integer quantization typically uses a scale and offset (although FP8 only needs a scale according to NVIDIA, not an offset). Given an FP32 value, the quantized value is obtained by dividing by the scale and casting to FP8. Given an FP8 quantized value, the FP32 non-quantized value is obtained by casting to FP32 and multiplying by the scale.
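In Python-like terms, a minimal sketch of this relationship (the helper names and the use of JAX and `jnp.float8_e4m3fn` here are illustrative, not part of the RFC):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite E4M3 magnitude

def quantize_to_fp8(x_fp32, scale):
    # Divide by the scale, saturate, then cast to FP8.
    return jnp.clip(x_fp32 / scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)

def dequantize_from_fp8(x_fp8, scale):
    # Cast back to FP32 and multiply by the scale.
    return x_fp8.astype(jnp.float32) * scale
```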
Tensor scaling allows values in the tensor to be brought into a representable range.
Most tensors on the forward pass and many on the backward pass require their own scale. The optimal value of the scale is such that it causes the tensor to barely not overflow. In other words, the optimal value is `max(fp32_tensor) / max_representable_fp8_value`. `max(fp32_tensor)` refers to the maximum absolute value of the tensor, and is often referred to as "amax".

For example, suppose a full precision 3-element tensor has values `[2^-14, 2, 7]`, and that we would like to represent it as an E4M3 tensor. The max E4M3 value is 448, and so the optimal scale is `7 / 448 = 1/64 ≈ 0.016`. This means the FP8 tensor is represented by `[2^-8, 128, 448]`, which barely does not overflow and brings the first element to the representable value of `2^-8` (the minimum positive E4M3 number is `2^-9`). Typically scales will be less than one, and so FP8 tensors will have larger values than the corresponding full precision tensors.

Scaling in this way is a form of symmetric quantization, which historically has been done on integer tensors, not floating-point tensors. A significant difference between FP8 quantization and integer quantization is that FP8 quantization is done during training, which means we must both compute the scale and use the scale to quantize tensors for each training step. Integer quantization aware training is also done during training, but unlike FP8 quantization, quantization aware training typically does not significantly improve per-step training performance.
Unfortunately, during FP8 training, it is not feasible to efficiently compute `tensor` and `max(tensor)`, then use `max(tensor)` to quantize `tensor` to FP8. This is because computing `max(tensor)` requires iterating over `tensor` in a wider precision, but we want to start quantizing certain elements of `tensor` before computing other elements of `tensor`, to avoid storing all of `tensor` in the wider precision. The solution suggested by NVIDIA is that we use `max(tensor)` to compute the scale for the next step, not the current step. For example, a quantized matmul would numerically be done in the following way during training:
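Sketched in JAX-style Python (illustrative only; the function and argument names follow the discussion below, while the availability of `jnp.float8_e4m3fn` and the 1.1 default slack value are assumptions):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0

def quantized_matmul(quantized_x, quantized_y, x_scale, y_scale, z_scale, slack=1.1):
    # Dequantize: cast the FP8 inputs to a wider type and multiply by their scales.
    x = quantized_x.astype(jnp.float32) * x_scale
    y = quantized_y.astype(jnp.float32) * y_scale
    # Run the matmul in the wider type.
    z = jnp.dot(x, y)
    # Compute the scale for the *next* step from this step's amax.
    new_z_scale = slack * jnp.max(jnp.abs(z)) / E4M3_MAX
    # Quantize the output with the *current* step's scale, saturating on overflow.
    quantized_z = jnp.clip(z / z_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    return quantized_z, new_z_scale
```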
The inputs and outputs of `quantized_matmul` are FP8, and the three scales are taken in as inputs. The function returns the quantized output and the new scale. `new_z_scale` will be the scale of `z` for the next step. `slack` is used to increase the scale slightly, in case the matmul output is slightly higher in the next step. If the matmul output is significantly higher in the next step, it may overflow in the next step, and so FP8 should use saturation on overflow, which results in the max FP8 value on overflow instead of Inf.

In this example, the wider precision that `quantized_x` and `quantized_y` are cast to is FP32, but it can also be FP16 or BF16. The choice is up to the user, although frameworks like TensorFlow, Keras, and JAX may be opinionated on the wider type. BF16 may be preferred since BF16 arithmetic is faster than FP32 on most backends and, unlike FP16, BF16 has significantly higher dynamic range than E5M2.

NVIDIA recommends the new scale be calculated based on the maximum amax value over a window of the past N steps, to ensure that a step with an unusually low amax value does not negatively affect the next step. The choice of how to compute the scale is up to the user and does not affect the compiler design. This RFC describes the new scale being solely a function of the amax value of the previous step, since it makes the examples simpler.
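For illustration, a small amax-history helper along these lines (the class itself, the window size, and the slack value are illustrative, not something the RFC specifies):

```python
from collections import deque
import jax.numpy as jnp

E4M3_MAX = 448.0

class AmaxHistory:
    """Tracks the last N amax values and derives the next scale from their maximum."""

    def __init__(self, window_size=16, slack=1.1):
        # window_size and slack are illustrative defaults.
        self.history = deque(maxlen=window_size)
        self.slack = slack

    def update(self, tensor):
        self.history.append(float(jnp.max(jnp.abs(tensor))))

    def next_scale(self):
        # Scale so that the largest amax seen in the window barely fits in E4M3.
        return self.slack * max(self.history) / E4M3_MAX
```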
The `quantized_matmul` example above only shows what is numerically done, not what the hardware executes in practice.
The use of symmetric quantization for FP8 is recommended by NVIDIA and directly supported by the cuBLAS and cuDNN libraries, but other researchers have experimented with using FP8 without scaling. For example this paper from GraphCore achieves good results without scaling (although it does use different E4M3 exponent biases for weights vs activations). Ultimately, it will be up to users to decide, although higher level frameworks may be opinionated on how FP8 should be used.
Design
In HLO, MHLO, and StableHLO, two new dtype enum values will be added: `f8E5M2` and `f8E4M3`. These will correspond to the dtypes as proposed by NVIDIA, Intel, and ARM and supported by Hopper.

E5M2 was recently added to MLIR and the LLVM helper class APFloat, and E4M3 will follow, although these dtypes are not (yet) being added to LLVM IR. See the LLVM RFC here. MLIR support for these dtypes is a prerequisite for adding them to StableHLO and XLA.

The naming convention of E5M2 and E4M3 in HLO will follow MLIR's naming convention. E5M2 was already added to MLIR as `Float8E5M2`, and so the HLO/MHLO/StableHLO type will similarly be named `f8E5M2`.

The E4M3 type hasn't been added to MLIR yet. Since the dtype has unusual non-IEEE-compliant semantics, it may have a more NVIDIA-specific name. The dtype name is referred to as `f8E4M3` in this RFC but will likely be different in practice. Because the dtype name for E4M3 will be decided by MLIR and not XLA, it is not further discussed in this RFC.

If other vendors wish to support their own FP8 dtypes, they should first propose adding them to MLIR. Once accepted and implemented, we can consider supporting such types in StableHLO and XLA on a case-by-case basis. For now, our focus is on the two FP8 dtypes supported by Hopper GPUs.
Scaling
As stated in the background section, NVIDIA recommends FP8 be used with symmetric quantization. The scaling for FP8 symmetric quantization will be represented using normal multiply and divide ops in HLO and StableHLO.
There are several possibilities for how to represent scaling using multiply and divide ops. In this section, we first present a generic approach that will work with any op. The subsection "Alternative way to scale" will present an alternative representation for `Dot` and `Conv` which closely matches what Hopper hardware (and likely other FP8 hardware) executes in practice.

With the generic approach to scaling, running an op, such as `Dot` or `Add`, when training an FP8 model will be represented by the following steps in HLO and StableHLO:

1. Cast the FP8 inputs to FP16.
2. Multiply the FP16 inputs by their input scales.
3. Run the op, such as `Dot`, on the FP16 inputs, getting an FP16 output.
4. Compute the maximum absolute value (amax) of the FP16 output with a Reduce.
5. Divide the FP16 output by the output scale.
6. Cast the result to FP8.
7. Compute the new output scale for the next step from the amax.

Here is an abridged example of how an FP8 matmul which multiplies an input with itself would look like in StableHLO during training. Some lengthy sections of code are replaced by "..." for brevity.
Note that the input is unscaled after being cast to FP16, and the output is scaled before being cast to FP8. There should never be an unscaled FP8 tensor, because otherwise the FP8 tensor may underflow or overflow.
In this example, the new scale is computed as 1.1 * (z_max / 448). The (z_max / 448) part is to create a scale that will cause the FP8 tensor to barely not overflow, since 448 is the maximum representable E4M3 value. The scale is multiplied by 1.1, a "slack" value, in case the tensor during the next step has a slightly higher maximum value. See the "Background" section for details.
On NVIDIA Hopper GPUs, steps (1)-(6) can all be done by a single cuBLAS function call, but this requires a minor modification to step (5). cuBLAS requires the inverse output scale (i.e. `1/z_scale`) to be passed instead of the output scale itself. Instead of dividing the output by the output scale, cuBLAS multiplies the output by the inverse output scale. This is mathematically equivalent, but requires XLA to compute `1/z_scale` before passing it to cuBLAS. The way XLA GPU will handle this is described in the "XLA GPU codegen" section.

When doing inference with a static scale, steps (4) and (7) are not needed since the scale is not updated. The other five steps are identical to the training case.
For dynamic range inference quantization, where the scale is dynamic during inference, steps (1)-(7) can be done similarly to training. Traditionally dynamic range quantization computes the scale for the given step and uses it the same step, instead of computing the scale for the next step. The example above can be modified to do this by running Step 7 before Step 5 and using the newly computed scale in Step 7 to scale the output in Step 5. But on Hopper GPUs, this will be significantly slower, likely reducing performance to be worse than even FP16 performance.
Alternative way to scale
In the above section, ops like Dot have inputs and outputs in FP16 (or BF16 or FP32). However, Hopper hardware directly supports matmuls with FP8 inputs and FP16/BF16 outputs. To better have HLO match what hardware supports, we can represent Dot and Conv ops in HLO with FP8 inputs and FP16/BF16 outputs as well. This approach does not work for arbitrary ops such as Add, however.
To show how this alternative approach can be done, we show the generic scaling example in the above section using Python-like pseudocode, which has a Dot with FP16 inputs and outputs. We then show a new equivalent example that instead has a Dot with FP8 inputs and FP16 outputs.
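A sketch of the two variants in JAX-flavored Python (illustrative: it assumes a recent JAX where `jnp.dot` accepts FP8 inputs with a `preferred_element_type`, and it reuses the 448 maximum and 1.1 slack from the earlier example):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0

def quantized_dot_generic(x_fp8, y_fp8, x_scale, y_scale, z_scale):
    # Generic approach: the dot runs on FP16 inputs that were already unscaled.
    z_fp16 = jnp.dot(x_fp8.astype(jnp.float16) * x_scale,
                     y_fp8.astype(jnp.float16) * y_scale)
    new_z_scale = 1.1 * jnp.max(jnp.abs(z_fp16)) / E4M3_MAX
    z_fp8 = jnp.clip(z_fp16 / z_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    return z_fp8, new_z_scale

def quantized_dot_alt(x_fp8, y_fp8, x_scale, y_scale, z_scale):
    # Alternative approach: the dot runs directly on FP8 inputs, producing an
    # FP16 output, which is then unscaled using the input scales.
    z_fp16 = jnp.dot(x_fp8, y_fp8,
                     preferred_element_type=jnp.float16) * (x_scale * y_scale)
    new_z_scale = 1.1 * jnp.max(jnp.abs(z_fp16)) / E4M3_MAX
    z_fp8 = jnp.clip(z_fp16 / z_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    return z_fp8, new_z_scale
```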
Note that both functions are identical except for their first lines (the computation of `z_fp16`).
Both functions are mathematically and roughly numerically equivalent. The former function unscales the inputs then runs the dot. The latter function runs the dot on the scaled inputs, resulting in a scaled output, then unscales the output using the input scales. This relies on the fact that it's equivalent to scale the inputs or the output. That is, we have the property that given any matrices `x` and `y` and any scalars `xs` and `ys`, we have `dot(x * xs, y * ys) == dot(x, y) * xs * ys`.

This property is true only for a limited set of ops, notably Dot and Conv. Therefore, this alternative representation of scaling cannot be used for arbitrary ops.
Hopper supports FP8 arithmetic through matmuls with FP8 inputs and FP16/BF16 outputs. Therefore, the alternative representation closely matches what Hopper hardware executes. Other FP8 hardware will likely also execute FP8 matmuls similarly. This will make it easier for compiler backends to emit code for Dot and Conv instructions (Conv is typically implemented via matmuls).
Currently in XLA GPU, we plan to only support FP8 Dot and Conv instructions through cuBLAS and cuDNN, and neither representation makes this easier than the other. But the alternative representation will potentially make emitting Dot and Conv instructions easier on other backends, as well as on the XLA GPU backend if it ever chooses to generate its own Dot and Conv code instead of going through cuBLAS and cuDNN. XLA GPU will pattern match both representations to cuBLAS/cuDNN calls.
StableHLO quantization types and ops
StableHLO has special quantized types and ops, which support both symmetric quantization and asymmetric quantization. This quantization support would allow us to directly represent the scaling done by FP8 symmetric quantization without needing to use explicit multiply and divide ops. This would make pattern matching to cuBLAS calls or other backend-specific ops much simpler.
However, we will initially still use multiply and divide ops to represent scaling instead of the quantized types and ops, using either the original scaling representation or the alternative. The primary reason is that fully supporting the quantized types/ops will take a considerable amount of work, and so relying on them would not allow us to support FP8 on Hopper GPUs by the end of the year. In particular, using the quantized types/ops for FP8 requires, among other things, support for dynamic scales: the `UniformQuantizedType` type, which StableHLO's quantized types use, only supports compile-time constant scales, while FP8 training requires the scale to change each step.

In the future, we will consider using StableHLO's quantized types and ops. This RFC makes no recommendation on whether to switch to these types and ops in the future.
The appendix has more information on how FP8 can be represented with StableHLO's quantized types and ops in the future.
When scaling: multiply vs. divide
Recall from the "Background" section that when quantizing from FP32 to FP8, the FP32 value is divided by the scale. When dequantizing from FP8 to FP32, the FP8 value is multiplied by the scale.
This is consistent with integer symmetric quantization, where the FP32 value is converted to INT8 by dividing by the scale.
However, FP16 loss scaling traditionally takes the opposite approach: an FP32 value is multiplied by the scale to convert to FP16, and an FP16 value is divided by the scale to convert to FP32. Unlike both FP8 quantization and integer quantization, the wider type is multiplied by the scale to get to the narrower type.
In this doc, we choose to express FP8 quantization similarly to symmetric quantization, in that the wider type is divided by the scale to get the narrower type. Despite the fact that both FP8 and FP16 are floating-point types, FP8 scaling is more similar to integer quantization than to FP16 loss scaling, and so it makes more sense to follow the integer quantization convention.
Some models may use both loss scaling and FP8 quantization. In this case, converting from FP32 to FP8 involves multiplying by the loss scale and dividing by the tensor scale (in practice, intermediate tensors are almost never multiplied or divided by the loss scale, however).
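As an illustration (a hypothetical helper, not from the RFC; assumes a recent JAX with `jnp.float8_e4m3fn`):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0

def to_fp8_with_loss_scale(x_fp32, loss_scale, tensor_scale):
    # Loss scaling multiplies by its scale; FP8 tensor scaling divides by its scale.
    scaled = x_fp32 * loss_scale / tensor_scale
    return jnp.clip(scaled, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
```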
Note that cuBLAS uses one convention for the inputs and the other for the outputs. In particular, the FP8 inputs are multiplied by a value to convert them to a wider type, but the outputs are also multiplied by a value to convert them to FP8. By the convention of this doc, we say the FP8 inputs are multiplied by the input scales, and the outputs of a wider type are multiplied by the inverse output scale.
The choice of convention of whether to multiply or divide does not significantly affect the compiler design itself, but mostly affects how scaling is described in this doc. When pattern matching to cuBLAS calls, XLA will support any combination of multiplying or dividing the inputs by a scale and multiplying or dividing the output by a scale. The choice of convention will have a significant impact on high-level ML frameworks which support quantization, such as Keras.
FP8 arithmetic
The "Scaling" section has shown that when scaling is used, no arithmetic ops such as Dot should ever run with both non-quantized FP8 inputs and non-quantized FP8 outputs. This is because every non-quantized FP8 tensor should represent a scaled tensor when scaling is used, but arithmetic operations do not take in a scale.
Running arithmetic ops like Dot with FP8 inputs and outputs will still be supported, however. For example, the following will be allowed:
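Expressed at the JAX level rather than in HLO (illustrative; assumes a recent JAX with FP8 dtype support, and the corresponding HLO would be an add with `f8E4M3` operands and result):

```python
import jax.numpy as jnp

# FP8 inputs, FP8 output, no scales involved.
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float8_e4m3fn)
y = jnp.array([0.5, 0.25, 4.0], dtype=jnp.float8_e4m3fn)
z = x + y
```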
NVIDIA recommends the use of scaling in all cases. Other companies, such as GraphCore, have successfully used FP8 without scaling. Users of XLA can ultimately choose whether to scale or not. If scaling ends up being highly recommended in many cases, frameworks such as Keras and JAX can choose, if they want, to warn when FP8 is used without scaling.
NVIDIA GPUs do not directly support arithmetic operations on FP8 values, but a pass similar to bf16-normalization can upcast the inputs to arithmetic ops to get numerically equivalent results.
If FP8 arithmetic was not directly supported in HLO and StableHLO, it could still be emulated by converting the input tensors to FP16, running the arithmetic instruction in FP16, then converting the output back to FP8. This is effectively running steps (1)-(6) above with a scale of 1. But doing this would be tricky for users compared to directly using FP8 tensors and would make FP8 inconsistent with other floating-point dtypes, which is why FP8 arithmetic will be supported.
Rounding
When converting to FP8, XLA will use the typical round-to-even behavior as used in other floating-point dtypes. However, in practice, FP8 should saturate on overflow, because the scale might end up being slightly too large. Therefore, when casting to FP8, frameworks like TensorFlow and JAX should emit a Clamp instruction, to clamp to the highest possible FP8 value, before emitting the Convert instruction.
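For example, a framework-side saturating cast might look like this (illustrative; assumes `jnp.float8_e4m3fn` and the E4M3 maximum of 448):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0

def saturating_cast_to_e4m3(x):
    # Clamp to the largest finite E4M3 magnitude, then convert with round-to-even.
    return jnp.clip(x, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
```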
Note that with FP8 arithmetic, as described in the above section, there will be no option to clamp FP8 outputs of arithmetic ops.
The ReducePrecision instruction can model the E5M2 type but not the E4M3 type, since E4M3 lacks Inf values and ReducePrecision assumes normal IEEE-like semantics. It is currently unclear how to extend ReducePrecision to support FP8 types, since the existing FP8 implementations differ in terms of Inf, NaN, and -0 representations. Future FP8 types may differ in other ways. Therefore, we will wait until FP8 becomes available on more hardware before deciding whether and how to add E4M3 support to ReducePrecision.
Stochastic rounding has been shown in many cases to result in better model quality compared to round-to-even, especially for low-precision dtypes such as FP8. A StochasticConvertType instruction was recently added, and support for this instruction is being added to XLA backends. Since stochastic rounding is not FP8-specific, it is not further considered in the FP8 design, although it may be important in achieving optimal model quality.
XLA GPU codegen
cuBLAS directly supports scaling and computing the maximum output value for matmuls. See the documentation for details. As stated before, steps (1)-(6) in the "Scaling" section above can be run with a single cuBLAS function call, with a minor modification: for the output scale, cuBLAS requires the inverse output scale to be passed into the matmul function, and cuBLAS multiplies the output by the inverse output scale, instead of dividing the output by the output scale. This allows cuBLAS to avoid many costly divisions, and the caller only has to pay the cost of a single scalar reciprocal.
XLA will pattern match steps (1)-(6) into a cuBLAS call. This rewrite will be done in the gemm-rewriter pass. Because these steps divide the output by the output scale and cuBLAS takes in the inverse output scale, XLA will additionally insert a divide instruction on the output scale before passing it to the cuBLAS call. A horizontal fusion pass can later fuse these scalar divisions into a single kernel.
XLA will also pattern match a version of steps (1)-(6) where the output is multiplied by the inverse scale instead of divided by the scale. In this case, the gemm-rewriter pass does not need to insert a divide instruction to compute the reciprocal. NVIDIA recommends frameworks like TensorFlow do this by computing the scale and inverse scale at the same time. However, this approach will make using the current form of the StableHLO quantized types and ops more awkward, as the `stablehlo.uniform_quantize` op expects a scale, not an inverse scale.

Additionally, XLA will pattern match the pattern described in the section "Alternative way to scale", which is numerically equivalent to steps (1)-(6).
cuDNN also supports FP8 with scaling for convolutions, but currently only on the forward pass. As with matmuls in cuBLAS, we will rewrite FP8 convolutions to cuDNN calls.
When emitting LLVM IR, XLA will represent FP8 as int8, using the NVIDIA Hopper hardware instructions to convert to wider types to do arithmetic. The PTX cvt instruction to convert types currently only supports converting a vector of two FP8 values at a time; since this is difficult to support, XLA codegen will initially convert only a single FP8 value at a time, passing an unused placeholder value for the second input.
When the HLO directly does FP8 arithmetic, a pass similar to bf16-normalization will upcast tensors so that no FP8 arithmetic is done.
As of commit 72eb5d2b, XLA supports fusing steps (1)-(6) in the "Scaling" section (FP8 is not yet supported, but XLA can fuse (1)-(6) when higher-precision dtypes are used). For instructions other than convolutions and dots, fusing steps (1)-(6) is not done by an FP8-specific pass but instead is done as part of the general fusion passes.
Testing plan
Unit tests will be added to XLA that test FP8 correctness.
Convergence and performance testing on Hopper GPUs will be done through TensorFlow and JAX. Both frameworks plan on adding FP8 support by the end of 2022, although FP8 will not necessarily be easy to use at first. We will find a TensorFlow or JAX ResNet50 model and a TensorFlow or JAX BERT model, fork each, add FP8 support, and then run performance and convergence tests.
Unfortunately, we do not have a baseline for FP8 performance. If, say, we find BERT is 20% faster on Hopper using FP8 compared to FP16, we will not know if the FP8 speedup is close to optimal. We will work with NVIDIA to determine if our performance results are acceptable.
As a stretch goal, we will also port FP8 to a GPT-3-like model and test convergence and performance.
Unresolved issue: When to scale
This section describes an unresolved issue that affects the boundary between frameworks such as TensorFlow/JAX and XLA. Suppose a user writes the following function (using TensorFlow notation):
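For concreteness, a stand-in function with the shape the discussion below assumes (an elementwise multiply feeding an add; the constants are made up):

```python
import tensorflow as tf

def f(x):
    # Placeholder ops: the discussion below only relies on a multiply feeding an add.
    y = x * 2.0  # elementwise multiply
    z = y + 3.0  # elementwise add
    return z
```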
To convert to FP8, the user needs to add casting, scaling, and reduce_max computations. Suppose the user chooses to represent `x` and `z` as FP8 tensors, keeping `y` in BF16.

The user starts by converting `x` into a BF16 tensor. After computing `y` and `z` in BF16, the user computes the max of `z` and converts the result back to FP8.
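A sketch of that FP8 version (illustrative; assumes a TensorFlow build that exposes `tf.float8_e4m3fn`, and reuses the 448 maximum and 1.1 slack from the "Scaling" section):

```python
import tensorflow as tf

E4M3_MAX = 448.0

def f_fp8(x_fp8, x_scale, z_scale):
    # Dequantize the FP8 input into BF16.
    x = tf.cast(x_fp8, tf.bfloat16) * tf.cast(x_scale, tf.bfloat16)
    y = x * 2.0                        # y stays in BF16
    z = y + 3.0
    # Compute the amax of z, used for the next step's scale.
    z_max = tf.cast(tf.reduce_max(tf.abs(z)), tf.float32)
    new_z_scale = 1.1 * z_max / E4M3_MAX
    # Quantize z back to FP8 with the current scale, saturating on overflow.
    z_scaled = z / tf.cast(z_scale, tf.bfloat16)
    z_fp8 = tf.cast(tf.clip_by_value(z_scaled, -E4M3_MAX, E4M3_MAX), tf.float8_e4m3fn)
    return z_fp8, new_z_scale
```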
Why did the user keep `y` in BF16? The reason is that this is necessary for FP8 to have optimal performance with a compiler. In the original BF16 function, the compiler will fuse the multiplication and addition. By making `x` and `z` FP8, the compiler must additionally fuse the casts, scalings, and `reduce_max` into the computation.

However, if `y` were additionally made FP8, more scaling ops and another `reduce_max` would be added, to compute the max of `y`. This would lead to an unnecessary performance loss, and would likely split the fusion into multiple fusions. In general, there is no reason to use FP8 within a fusion. Outside convolutions and matmuls (which are handled by cuBLAS/cuDNN), the main purpose of FP8 is that it takes less memory. However, within a fusion, intermediate values are kept in registers, not GPU memory.

The big question is this: how does the user or framework know what tensors should be in FP8? The user/framework does not know what the compiler will fuse ahead of time, so it cannot ensure that tensors within fusions are BF16 and inputs/outputs to fusions are FP8.
There are no plans to address this initially by the end of 2022. TensorFlow/JAX users will have to be roughly aware of what is fused in order to get performance benefits in FP8. XLA has a flag, `xla_allow_excess_precision`, allowing it to increase the precision of tensors, but this doesn't allow it to skip the computation of the scaling ops or the `reduce_max` call. One solution for the long term is to develop a mechanism where XLA can skip the scaling and the `reduce_max` call if it increases the precision of the corresponding tensor.

Appendix: Details on StableHLO quantization types and ops
The "Scaling" section briefly described how StableHLO has special quantized types and ops, but that these will not be initially used for FP8 symmetric quantization. This section describes how they could potentially be used with FP8 in the future.
We start by giving an example of how to quantize a tensor using StableHLO quantized types/ops:
`%qx` is the quantized version of the floating-point tensor `%x`. Let's start by examining `%qx`'s element type, which is `!quant.uniform<i8:f32, 2.0:1>`. The `!quant.uniform` type is defined in the MLIR repository itself and in this example is parameterized with `<i8:f32, 2.0:1>`. The `i8:f32` parameters mean the storage type is `i8` while the expressed type, which is the type the tensor is approximating, is `f32`. The `2.0:1` parameters mean the scale is 2 and the zero point is 1, which indicate how to convert between the quantized and real values: the real value is `(quantized_value - zero_point) * scale`, and the quantized value is `real_value / scale + zero_point` (rounded to the storage type).
The `stablehlo.uniform_quantize` op converts a floating-point tensor to an integer quantized tensor using the formula above. Because `%x` is `[[2., 4.], [6., 8.]]`, the integer representation of `%qx` is `x / 2 + 1 = [[2, 3], [4, 5]]`.

Quantized types can be directly passed to ops. Let's continue the example above by passing `%qx` to the `stablehlo.add` op (adding it to itself). When quantized types are passed directly to ops, the op takes into account the scale and zero point. In the above example, the addition does not just add the two integer tensors as if they were non-quantized. Instead, it does the equivalent of dequantizing the inputs into floating-point tensors, running the floating-point addition, then quantizing the output back to an integer format. Therefore, the result of the addition, `%qy`, represents the FP32 tensor `2 * x = [[4., 8.], [12., 16.]]`, and its quantized integer representation in memory is `(2 * x) / 2 + 1 = [[3, 5], [7, 9]]`.
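The quantized-add semantics can be checked numerically with plain NumPy (this snippet is illustrative, not StableHLO):

```python
import numpy as np

def quantized_add(qx, qy, scale, zero_point):
    # Dequantize both inputs, add in float, then requantize the result.
    x = (qx.astype(np.float32) - zero_point) * scale
    y = (qy.astype(np.float32) - zero_point) * scale
    z = x + y
    return np.round(z / scale + zero_point).astype(np.int8)

qx = np.array([[2, 3], [4, 5]], dtype=np.int8)  # represents [[2., 4.], [6., 8.]] with scale=2, zero_point=1
print(quantized_add(qx, qx, scale=2.0, zero_point=1))  # [[3 5] [7 9]]
```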
For FP8 training, the scale should not be a compile-time constant, but instead should be dynamically computed and updated at runtime. Unfortunately, this is not yet possible with the `!quant.uniform` type, which is one of the reasons why FP8 quantization will not initially use the quantized types, but support for runtime scales may be added in the future.

Once (or if) there is support for dynamic scales in MLIR's `!quant.uniform`, running an op such as Dot when training an FP8 model can be represented in StableHLO using the quantized types and ops, with the op's output quantized via `stablehlo.uniform_quantize`.

Here is an abridged example of how an FP8 matmul which multiplies an input with itself would look like in StableHLO during training.
The `"stablehlo.dot"` operation returns an FP16 output instead of a quantized FP8 output because the maximum value of the FP16 tensor first needs to be computed before the tensor is quantized.

As stated earlier, this RFC does not make a recommendation on whether the StableHLO quantized ops and types will be used in the future. Initially, they will not be used, as scaling will be represented using multiply and divide ops.