Skip to content

NVfp4#2408

Merged
drisspg merged 1 commit into
mainfrom
drisspg/stack/78
Jun 24, 2025
Merged

NVfp4#2408
drisspg merged 1 commit into
mainfrom
drisspg/stack/78

Conversation

@drisspg

@drisspg drisspg commented Jun 18, 2025

Copy link
Copy Markdown
Contributor

Stacked PRs:


Add NVFP4 Inference flow

Details:

I kept this separate for MX but realistically we should probably merge the two. Basic support for blocksize 16 + e4m3 scales.

Double Quant Update

Ignore previous comments, the double quant is actually really similar to NF4 where you just scale the fp32 scales prior to casting to e4m3 to try and reduce scale quant error.

I have that implemented now in the Nvfp4 code if a tesor_scale is given, just need to figure out how to thread to cublas param scale_in_d or how we want to expose this. We currently don't expose the C matrix to the Python API so we could use alpha as @gau-nernst pointed out to me, however we dont expose alpha either 🙃. However if we wanted to use alpha we would need the value on the host, the sync would likely rule out this option. I might keep this double quant on hold until we have the public api, since I am thinking about adding scale overloads to addmm. However I read the cublas docs many times and it feels as though passing to scale result should work since we don't set the d_mode and its default value should work.

Early Perf

No double quant here

python /home/drisspg/meta/vllm/benchmarks/benchmark_throughput.py \
 --backend vllm \
 --model "data/nvfp4-Qwen3-8B" \
 --dataset-name sharegpt \
 --dataset-path data/ShareGPT_V3_unfiltered_cleaned_split.json \
 --num-prompts 1024 \
 --disable-log-stats \
 --gpu-memory-utilization=0.9 \
 --seed 42
Throughput: 43.23 requests/s, 18347.24 total tokens/s, 8840.47 output tokens/s
Total num prompt tokens:  225190
Total num output tokens:  209407

which is even worse than mxfp4..., will profile later

Micro Bench

LLama 70B mlp no TP:

Model Configuration Runtime (μs/iteration) Speedup vs BF16
BF16 1353.09 1.00x
mxfp8 766.76 1.76x
mxfp4 638.00 2.12x
nvfp4 540.41 2.50x

Diffusers

# Bf16 Compile
|           ckpt_id            |   batch_size |  fuse  |  compile  |  compile_vae  |  quantization  |  sparsify  |   model_memory |   inference_memory |   time |
|:----------------------------:|-------------:|:------:|:---------:|:-------------:|:--------------:|:----------:|---------------:|-------------------:|-------:|
| black-forest-labs/FLUX.1-dev |            1 | False  |   True    |     False     |      None      |   False    |         31.438 |             33.827 |  3.286 |

Errors

Annoyingly we are getting an error due to the view as fp4x2 + packing https://fburl.com/cd92w431 because this is trying to be bitcast iside inside triton kernel which is very annoying. Not sure how this didn't show up until vllm / w/ mxfp4
^ similar to this: triton-lang/triton#6054 but make the same changes in _inductor/utils.py as we did for float8em0

Numerics

Script: https://gist.github.com/drisspg/4024ed055a6db911495102614c674c4c -> still emulating till we fix this bug in cublaslt bindings
Double quant really helps w/ tensor that have very small amax values, likely by reducing the amount of underflows will verify:
nvfp4_gelu_performance_heatmap

Flow

float4_two_level scale

@pytorch-bot

pytorch-bot Bot commented Jun 18, 2025

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2408

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending

As of commit 4fe3daf with merge base 4e25496 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Jun 18, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from c58c5b0 to 3948f5d Compare June 18, 2025 20:33
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 18, 2025
drisspg added a commit that referenced this pull request Jun 18, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 3948f5d to 1025236 Compare June 18, 2025 21:05
drisspg added a commit that referenced this pull request Jun 18, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 1025236 to 1c007a4 Compare June 18, 2025 21:30
@drisspg drisspg added mx topic: new feature Use this tag if this PR adds a new feature labels Jun 19, 2025
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 1c007a4 to a3d2874 Compare June 19, 2025 04:19
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from a3d2874 to 034f892 Compare June 19, 2025 04:26
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 034f892 to 92e0622 Compare June 19, 2025 04:27
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 92e0622 to b2c45a1 Compare June 19, 2025 04:38
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from b2c45a1 to 7448f45 Compare June 19, 2025 04:56
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 7448f45 to fad58b5 Compare June 19, 2025 16:00
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from fad58b5 to b5a593d Compare June 19, 2025 16:03
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from b5a593d to 2b4ba64 Compare June 19, 2025 16:31
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 2b4ba64 to 5d50579 Compare June 19, 2025 23:00
drisspg added a commit that referenced this pull request Jun 19, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 5d50579 to 29fa9ef Compare June 19, 2025 23:36
@drisspg drisspg force-pushed the drisspg/stack/78 branch from 79720b3 to f194e35 Compare June 20, 2025 23:56
drisspg added a commit that referenced this pull request Jun 21, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from f194e35 to b08a108 Compare June 21, 2025 00:17
@drisspg drisspg marked this pull request as draft June 21, 2025 00:46
"scale": None,
}

quantized_weight = to_linear_activation_quantized(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just write the logic here instead of using to_linear_activation_quantized? I remember same feedback on the mxfp4 inference tensor.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you can't just move the logic out here, the entirety of the forward behavior has to be "wrapped" by the subclass. Currently there are two ways to do that, without changing nn.modules.

  1. Like above; this is subclass composition
  2. The other is to copy the same behavior into the implementations of the ops,
    e.g.

NVFP4's dispatch would need to copy:

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if isinstance(weight_tensor, LinearActivationQuantizedTensor):
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
raise NotImplementedError(
"LinearActivationQuantizedTensor: No specialized dispatch found for linear op"
)
@implements([aten.mm.default, aten.addmm.default])
def _(func, types, args, kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(
"LinearActivationQuantizedTensor: expecting a floating point input"
)
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_tensor, bias = (
args[1],
args[2],
args[0],
)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
return func(bias, qtensor, original_weight_tensor)
else:
# aten.mm.default
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_tensor = (
args[0],
args[1],
)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
return func(qtensor, original_weight_tensor)

Not the end of the world. But for some subclasses that serve dual purpose (dyanmic + weight only, + static, + training) it can be alot of switch statements in the ops as opposed to having the base subclass + some sugar

Comment thread torchao/prototype/mx_formats/mx_subclass.py Outdated
M, K = orig_shape[0], orig_shape[1]
data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
scale_e4m3_reshaped = self._scale_e4m3.view(M, K // self._block_size, 1)
data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a second rescaling somewhere?

@drisspg drisspg Jun 21, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my two-level scaling is ont working correctly yet
Fixed

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still confused. If there is per-tensor scaling, we scale by the per-tensor scale to convert to nvfp4. Do we need to undo this scaling when converting back to high precision?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it looks like it lives in self.get_scales(). Maybe we can rename the function to make it obvious that both scales are there, like self.get_blockwise_and_maybe_tensorwise_scales()?

Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 21, 2025
drisspg added a commit that referenced this pull request Jun 21, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from b08a108 to ac989f2 Compare June 21, 2025 05:54
drisspg added a commit that referenced this pull request Jun 21, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from ac989f2 to e9708eb Compare June 21, 2025 06:22
drisspg added a commit that referenced this pull request Jun 21, 2025
stack-info: PR: #2408, branch: drisspg/stack/78
@drisspg drisspg force-pushed the drisspg/stack/78 branch from e9708eb to b03916b Compare June 21, 2025 06:23
Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py
Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py Outdated
@drisspg

drisspg commented Jun 23, 2025

Copy link
Copy Markdown
Contributor Author

Weight only fails w/ compile and bisected to:

Likely from the work around to get triton to not error on e2m1

Disabling lowerings fixed the issue.
Starting bisect by getting upper bound.
Upper bound of 38 found for inductor.
Bisecting inductor - lowerings (Range: [0, 38], Midpoint: 19)
Bisecting inductor - lowerings (Range: [20, 38], Midpoint: 29)
Bisecting inductor - lowerings (Range: [30, 38], Midpoint: 34)
Bisecting inductor - lowerings (Range: [35, 38], Midpoint: 36)
Bisecting inductor - lowerings (Range: [35, 36], Midpoint: 35)
Binary search completed for inductor - lowerings. The bisect number is 36. Debug info: convert_element_type_5
Bisection status deleted.
   Bisection result: BisectionResult(backend='inductor', subsystem='lowerings', bisect_number=36, debug_info='convert_element_type_5')

6. Testing inductor config workarounds for WEIGHT_ONLY:
   {'inductor.coordinate_descent_tuning': False}           ERROR
   {'inductor.force_fuse_int_mm_with_mul': False}          ERROR
   {'inductor.post_grad_passes': False}                    ERROR
   {'inductor.pattern_matcher': False}                     ERROR
   {'inductor.epilogue_fusion': False}                     ERROR
   {'inductor.max_autotune': False}                        ERROR
   {'triton.autotune_pointwise': False}                    ✗ 3.1dB
   {'inductor.benchmark_kernel': False}                    ERROR
   {'inductor.aggressive_fusion': False}                   ERROR

7. Testing other compile backends:
   Backend 'eager': SQNR = 20.00 dBBackend 'aot_eager': SQNR = 20.00 dBskipping cudagraphs due to skipping cudagraphs due to cpu device (_tensor_constant0). Found from : 
   File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 70, in inner
    return fn(*args, **kwargs)

   Backend 'cudagraphs': SQNR = 20.00 dB

@drisspg drisspg mentioned this pull request Jun 23, 2025
@drisspg

drisspg commented Jun 23, 2025

Copy link
Copy Markdown
Contributor Author

@vkuzo updated to use the mm_config

@gau-nernst gau-nernst left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some comments

Comment thread torchao/prototype/mx_formats/mx_subclass.py Outdated
Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py
Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py
Comment thread torchao/prototype/mx_formats/mx_subclass.py
Comment thread torchao/prototype/mx_formats/mx_subclass.py Outdated
Comment thread torchao/prototype/mx_formats/mx_subclass.py Outdated


@implements([torch.nn.functional.linear, aten.linear.default])
def nvfp4_linear(func, types, args, kwargs):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious on why we need both linear and mm flavors instead of picking one? I thought linear was for torch function and mm variants for torch dispatch?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need the mm and addmm in most circumstances, however if you run under inference mdoe that is pre-dispatch you end up see these linear ops

Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py
Comment thread torchao/prototype/mx_formats/nvfp4_tensor.py Outdated

data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
data_lp = f32_to_f4_unpacked(data_scaled.float())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember this function was being pretty slow when we wrote the emulation code, curious how performance is now? There are likely things we can do to make it faster if this is in the hot path.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah @gaunerst has a nice triton kernel for doing the cast + pack that I am going to try in a follow up PR

stack-info: PR: #2408, branch: drisspg/stack/78
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. mx topic: new feature Use this tag if this PR adds a new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants