Why?
tl;dr:
- model size * 0.5
- throughput * 1.2ish (with a lot of caveats). See our benchmarks
Models today are usually trained in bf16, a floating point format that stores each number in 16 bits (2 bytes). At the billions-of-parameters scale, those bytes add up VERY quickly. The main reason for quantizing a model from bf16 to fp8 is memory reduction.
For example, meta-llama/Llama-3.3-70B-Instruct has 70 billion parameters, which at bf16 is 140 billion bytes, or 140 GB of data. A single H100 GPU has 80 GB of GPU RAM, so you'd need at LEAST 2xH100 to serve it, and likely more to leave room for kv cache. If you halve the number of bytes, it would only take 70 GB, comfortably fitting on 2xH100s and just barely fitting on 1xH100.
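To make the arithmetic concrete, here is a tiny sketch of the weights-only memory math (kv cache and activation memory come on top of this):
# back-of-the-envelope memory math for a 70B parameter model
params = 70e9
bytes_per_param_bf16 = 2  # 16 bits
bytes_per_param_fp8 = 1   # 8 bits
print(f"bf16 weights: {params * bytes_per_param_bf16 / 1e9:.0f} GB")  # 140 GB
print(f"fp8 weights: {params * bytes_per_param_fp8 / 1e9:.0f} GB")  # 70 GB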
Starting with the NVIDIA H100 GPU, GPUs have hardware support for 8 bit floating point numbers (fp8), meaning fp8 performance is >= bf16 performance (mostly). This performance gain comes from a few things:
- Model takes less GPU ram => more space for kv cache. Modern inference libraries (like vllm/sglang) will have higher/more stable performance with more space for kv cache
- Model parameters are half as big => less GPU memory bandwidth
- Depending on the GPU, fp8 FLOPS are just higher than bf16 FLOPS. E.g. see the H100 specifications: bfloat16 has ~2k teraFLOPS and fp8 has ~4k teraFLOPS
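Circling back to the hardware support mentioned above: you can check whether your GPU has fp8 support at all with something like the sketch below (the >= (8, 9) threshold for Ada/Hopper-class fp8 support is my assumption; check your inference library's docs for its exact requirement):
import torch

# compute capability 8.9 (Ada) / 9.0 (Hopper) and newer have fp8 tensor core support
major, minor = torch.cuda.get_device_capability()
print("fp8 hardware support:", (major, minor) >= (8, 9))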
How?
Note on executing fp8 models
When we talk about fp8 models, we are typically only talking about the weights being fp8. The actual execution of the model is still done in bf16. So all the intermediate tensors are still in bf16, and it's the underlying CUDA kernels that take in bf16 tensors and fp8 weights.
`fp8` weight
\
v
`bf16` input -> Linear -> `bf16` output
fp8 models still use a bf16 kv cache by default (since the kv cache stores kv values, which are intermediate tensors).
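To make the diagram above concrete, here is a minimal sketch of what an fp8 Linear conceptually does (real CUDA kernels fuse the dequantization into the matmul instead of materializing a bf16 copy of the weight):
import torch

def fp8_linear(x_bf16, weight_fp8, weight_scale):
    # dequantize the fp8 weight back to bf16, then do a normal bf16 matmul
    w_bf16 = weight_fp8.to(torch.bfloat16) * weight_scale
    return x_bf16 @ w_bf16.t()  # bf16 input -> bf16 output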
fp8 bit format
There are a number of different fp8 formats; the most common is float8_e4m3fn. Here are the bit patterns for the 8-bit and 16-bit formats:
Format | Bit Pattern | INF Support |
---|---|---|
float8_e4m3fn | ⚫🟩🟩🟩🟩🟥🟥🟥 | ❌ |
float8_e5m2fn | ⚫🟩🟩🟩🟩🟩🟥🟥 | ❌ |
bfloat16 | ⚫🟩🟩🟩🟩🟩🟩🟩🟩🟥🟥🟥🟥🟥🟥🟥 | ✅ |
float16 | ⚫🟩🟩🟩🟩🟩🟥🟥🟥🟥🟥🟥🟥🟥🟥🟥 | ✅ |
where: ⚫ = Sign bit, 🟩 = Exponent bit, 🟥 = Mantissa (fraction) bit
Here are some facts about float8_e4m3fn:
- This format has 1 sign bit, 4 bits for exponent (e4), and 3 bits for mantissa (m3)
- Values can be between [-448, +448]
- There are 256 representable values
- Infinity is not supported (the fn postfix stands for “finite numbers only” - there are other fp8 formats that do support infinity)
- NaN is supported
- Model parameters are typically stored using this format (note that inf is not usually present in pretrained model parameters)
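You can sanity check a few of these facts directly in PyTorch (a small sketch; the float8 dtypes require a reasonably recent PyTorch, roughly 2.1+):
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.bits)  # 8
print(info.max)  # 448.0
print(info.min)  # -448.0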
Here are all the possible float8_e4m3fn values:
torch.arange(256, dtype=torch.uint8).view(dtype=torch.float8_e4m3fn).tolist()
[0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125,
0.009765625, 0.01171875, 0.013671875, 0.015625, 0.017578125,
0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375,
0.029296875, 0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875,
0.05078125, 0.0546875, 0.05859375, 0.0625, 0.0703125, 0.078125,
0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875, 0.125,
0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0,
1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875, 2.0, 2.25, 2.5, 2.75,
3.0, 3.25, 3.5, 3.75, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 18.0, 20.0, 22.0,
24.0, 26.0, 28.0, 30.0, 32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0,
60.0, 64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0, 128.0,
144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0, 256.0, 288.0,
320.0, 352.0, 384.0, 416.0, 448.0, nan, -0.0, -0.001953125,
-0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875,
-0.013671875, -0.015625, -0.017578125, -0.01953125, -0.021484375,
-0.0234375, -0.025390625, -0.02734375, -0.029296875, -0.03125,
-0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125,
-0.0546875, -0.05859375, -0.0625, -0.0703125, -0.078125,
-0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875, -0.125,
-0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875,
-0.234375, -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625,
-0.4375, -0.46875, -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125,
-0.875, -0.9375, -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75,
-1.875, -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75, -4.0,
-4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5, -8.0, -9.0, -10.0,
-11.0, -12.0, -13.0, -14.0, -15.0, -16.0, -18.0, -20.0, -22.0,
-24.0, -26.0, -28.0, -30.0, -32.0, -36.0, -40.0, -44.0, -48.0,
-52.0, -56.0, -60.0, -64.0, -72.0, -80.0, -88.0, -96.0, -104.0,
-112.0, -120.0, -128.0, -144.0, -160.0, -176.0, -192.0, -208.0,
-224.0, -240.0, -256.0, -288.0, -320.0, -352.0, -384.0, -416.0,
-448.0, nan]
And here is how all the representable values are distributed (notice how there are waaaaay more values closer to 0! ):
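If you want to reproduce this yourself, here is a small sketch that makes the same point by bucketing the representable magnitudes (the bucket edges are arbitrary):
import torch

vals = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).float()
vals = vals[~vals.isnan()].abs()
# count how many representable values fall in each magnitude range
for lo, hi in [(0, 1), (1, 16), (16, 64), (64, 449)]:
    count = ((vals >= lo) & (vals < hi)).sum().item()
    print(f"{count} representable values with |x| in [{lo}, {hi})")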
So this leaves us with two questions for quantization:
- bf16 can store values between [-3.38953e+38, +3.38953e+38]; how do we fit that into fp8's range of [-448, +448]?
- How do we take advantage of the distribution of values in fp8?
Quantization - scaling to reduce precision loss & handle large values
Since bf16 and fp8 have different ranges, we need to scale the values to fit into the fp8 range. This scale is based on the max value of the data at bf16, and is roughly computed like:
# NOTE: this will be a single value
scale = x_bf16.abs().amax() / 448
Then once we have the scale we can quantize the bf16 tensor:
x_fp8 = (x_bf16 / scale).clamp(min=-448, max=448).to(torch.float8_e4m3fn)
Note that by dividing by the scale, the values should already be within the range of -448 to 448, so the extra .clamp() operation is just there to ensure this numerically.
And to dequantize (which is essentially done on the fly at runtime inside the CUDA kernels), you do this (noting that you have to store the scale values for the forward process):
x = x_fp8.to(torch.bfloat16) * scale
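Putting the pieces together, here is a self-contained round trip you can run to see the quantization error this introduces (the quantize/dequantize helpers are just for this sketch):
import torch

def quantize(x_bf16):
    # per-tensor scale based on the largest magnitude
    scale = x_bf16.abs().amax() / 448
    x_fp8 = (x_bf16 / scale).clamp(min=-448, max=448).to(torch.float8_e4m3fn)
    return x_fp8, scale

def dequantize(x_fp8, scale):
    return x_fp8.to(torch.bfloat16) * scale

x = torch.randn(4096, 4096, dtype=torch.bfloat16)
x_fp8, scale = quantize(x)
print((x - dequantize(x_fp8, scale)).abs().mean())  # small but non-zero error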
Block style scale
Above I showed the scale being a single value, but you can also have the scale applied to blocks of values in the tensor. If you look at some popular open source fp8 models, they typically use this option.
Why would you do this? To theoretically preserve accuracy: with a single per-tensor scale, one large outlier stretches the scale for the whole tensor, while with block scales an outlier only affects the block it lives in.
Given a weight_block_size of [128, 128], and a tensor of shape [N, K], the scale will be of size [N // 128, K // 128]:
E.g. assuming x is 2d, we have the code:
N, K = x.shape
n, k = weight_block_size
x = x.reshape(N // n, n, K // k, k)
scale = x.abs().amax(dim=[1, 3]) / 448
assert scale.shape == torch.Size([N // n, K // k])
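And here is a hedged sketch of the full block-wise round trip, broadcasting the [N // 128, K // 128] scale back over the blocks (the helper names are just for illustration, and the tensor dims are assumed to be divisible by the block size):
import torch

def block_quantize(x, weight_block_size=(128, 128)):
    N, K = x.shape
    n, k = weight_block_size
    blocks = x.reshape(N // n, n, K // k, k)
    # one scale per [n, k] block
    scale = blocks.abs().amax(dim=[1, 3]) / 448
    x_fp8 = (blocks / scale[:, None, :, None]).clamp(-448, 448)
    return x_fp8.to(torch.float8_e4m3fn).reshape(N, K), scale

def block_dequantize(x_fp8, scale, weight_block_size=(128, 128)):
    N, K = x_fp8.shape
    n, k = weight_block_size
    blocks = x_fp8.to(torch.bfloat16).reshape(N // n, n, K // k, k)
    return (blocks * scale[:, None, :, None]).reshape(N, K)

w = torch.randn(256, 512, dtype=torch.bfloat16)
w_fp8, w_scale = block_quantize(w)
assert w_scale.shape == (2, 4)  # [256 // 128, 512 // 128]
w_round_trip = block_dequantize(w_fp8, w_scale)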
Saving a quantized checkpoint
For compatibility with things like vllm there are a couple of things we need to do:
Add the scales to Linear layers
We need to add the previously computed weight_scale as a parameter to each of the Linear layers. This basically means just replacing the Linear layer with this custom PackedLinear class, where weight is the fp8 tensor, and weight_scale is the scale from the previous sections.
class PackedLinear(torch.nn.Module):
def __init__(self, weight: torch.Tensor, weight_scale: torch.Tensor):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
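And here is a rough sketch of how you might do the swap across a whole model, assuming a block_quantize helper like the one sketched earlier (which layers to skip varies per model; see the note at the end):
import torch

def quantize_model(model, ignored_layers=()):
    # find all Linear layers to quantize first, then swap them in place
    # NOTE: assumes bias-free Linear layers, as in most LLMs
    to_swap = [
        name for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and name not in ignored_layers
    ]
    for name in to_swap:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        linear = getattr(parent, child_name)
        weight_fp8, weight_scale = block_quantize(linear.weight.data)
        setattr(parent, child_name, PackedLinear(weight_fp8, weight_scale))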
Update model config
This part is really easy: just add a quantization_config into the model's config. This will also appear in the config.json file in the huggingface repo of the model.
model.config.quantization_config = {
    "quant_method": "fp8",
    "is_checkpoint_fp8_serialized": True,
    "activation_scheme": "dynamic",
    "weight_block_size": ...,  # `None` or `[int, int]`
    "ignored_layers": ...,  # list of module names that are not quantized
}
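Then saving is the normal huggingface flow (a sketch; it assumes model and tokenizer were loaded via transformers, the output directory name is arbitrary, and your transformers/safetensors versions are recent enough to serialize fp8 tensors):
model.save_pretrained("Llama-3.3-70B-Instruct-FP8")
tokenizer.save_pretrained("Llama-3.3-70B-Instruct-FP8")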
And that’s all we need to do for vllm!
NOTE: some models don’t support all layers being quantized. For example, vllm does not support the decoder.mlp.gate linear layer being quantized in Qwen3 MoE models.
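One way to handle this (a sketch; the .endswith pattern is specific to this example and would need adjusting for other models) is to collect those module names up front and pass them through ignored_layers:
# keep the MoE router gates in bf16
ignored_layers = [
    name for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear) and name.endswith(".mlp.gate")
]
quantize_model(model, ignored_layers=ignored_layers)  # the swap sketch from earlier
The same list is what goes into the ignored_layers field of the quantization_config above.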