NVfp4#2408
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2408
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 PendingAs of commit 4fe3daf with merge base 4e25496 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
c58c5b0 to
3948f5d
Compare
3948f5d to
1025236
Compare
1025236 to
1c007a4
Compare
1c007a4 to
a3d2874
Compare
a3d2874 to
034f892
Compare
034f892 to
92e0622
Compare
92e0622 to
b2c45a1
Compare
b2c45a1 to
7448f45
Compare
7448f45 to
fad58b5
Compare
fad58b5 to
b5a593d
Compare
b5a593d to
2b4ba64
Compare
2b4ba64 to
5d50579
Compare
5d50579 to
29fa9ef
Compare
79720b3 to
f194e35
Compare
f194e35 to
b08a108
Compare
| "scale": None, | ||
| } | ||
|
|
||
| quantized_weight = to_linear_activation_quantized( |
There was a problem hiding this comment.
can we just write the logic here instead of using to_linear_activation_quantized? I remember same feedback on the mxfp4 inference tensor.
There was a problem hiding this comment.
So you can't just move the logic out here, the entirety of the forward behavior has to be "wrapped" by the subclass. Currently there are two ways to do that, without changing nn.modules.
- Like above; this is subclass composition
- The other is to copy the same behavior into the implementations of the ops,
e.g.
NVFP4's dispatch would need to copy:
ao/torchao/quantization/linear_activation_quantized_tensor.py
Lines 135 to 186 in 7192edf
Not the end of the world. But for some subclasses that serve dual purpose (dyanmic + weight only, + static, + training) it can be alot of switch statements in the ops as opposed to having the base subclass + some sugar
| M, K = orig_shape[0], orig_shape[1] | ||
| data_f32 = data_f32.view(M, K // self._block_size, self._block_size) | ||
| scale_e4m3_reshaped = self._scale_e4m3.view(M, K // self._block_size, 1) | ||
| data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32) |
There was a problem hiding this comment.
is there a second rescaling somewhere?
There was a problem hiding this comment.
my two-level scaling is ont working correctly yet
Fixed
There was a problem hiding this comment.
I'm still confused. If there is per-tensor scaling, we scale by the per-tensor scale to convert to nvfp4. Do we need to undo this scaling when converting back to high precision?
There was a problem hiding this comment.
I see, it looks like it lives in self.get_scales(). Maybe we can rename the function to make it obvious that both scales are there, like self.get_blockwise_and_maybe_tensorwise_scales()?
Found in: pytorch/ao#2408 Pull Request resolved: #156461 Approved by: https://github.com/vkuzo
b08a108 to
ac989f2
Compare
ac989f2 to
e9708eb
Compare
e9708eb to
b03916b
Compare
|
Weight only fails w/ compile and bisected to: Likely from the work around to get triton to not error on e2m1 Disabling lowerings fixed the issue.
Starting bisect by getting upper bound.
Upper bound of 38 found for inductor.
Bisecting inductor - lowerings (Range: [0, 38], Midpoint: 19)
Bisecting inductor - lowerings (Range: [20, 38], Midpoint: 29)
Bisecting inductor - lowerings (Range: [30, 38], Midpoint: 34)
Bisecting inductor - lowerings (Range: [35, 38], Midpoint: 36)
Bisecting inductor - lowerings (Range: [35, 36], Midpoint: 35)
Binary search completed for inductor - lowerings. The bisect number is 36. Debug info: convert_element_type_5
Bisection status deleted.
Bisection result: BisectionResult(backend='inductor', subsystem='lowerings', bisect_number=36, debug_info='convert_element_type_5')
6. Testing inductor config workarounds for WEIGHT_ONLY:
{'inductor.coordinate_descent_tuning': False} ERROR
{'inductor.force_fuse_int_mm_with_mul': False} ERROR
{'inductor.post_grad_passes': False} ERROR
{'inductor.pattern_matcher': False} ERROR
{'inductor.epilogue_fusion': False} ERROR
{'inductor.max_autotune': False} ERROR
{'triton.autotune_pointwise': False} ✗ 3.1dB
{'inductor.benchmark_kernel': False} ERROR
{'inductor.aggressive_fusion': False} ERROR
7. Testing other compile backends:
Backend 'eager': SQNR = 20.00 dB ✓
Backend 'aot_eager': SQNR = 20.00 dB ✓
skipping cudagraphs due to skipping cudagraphs due to cpu device (_tensor_constant0). Found from :
File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 70, in inner
return fn(*args, **kwargs)
Backend 'cudagraphs': SQNR = 20.00 dB ✓ |
|
@vkuzo updated to use the mm_config |
|
|
||
|
|
||
| @implements([torch.nn.functional.linear, aten.linear.default]) | ||
| def nvfp4_linear(func, types, args, kwargs): |
There was a problem hiding this comment.
curious on why we need both linear and mm flavors instead of picking one? I thought linear was for torch function and mm variants for torch dispatch?
There was a problem hiding this comment.
We only need the mm and addmm in most circumstances, however if you run under inference mdoe that is pre-dispatch you end up see these linear ops
|
|
||
| data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX) | ||
| data_scaled = data_scaled.view(orig_shape) | ||
| data_lp = f32_to_f4_unpacked(data_scaled.float()) |
There was a problem hiding this comment.
I remember this function was being pretty slow when we wrote the emulation code, curious how performance is now? There are likely things we can do to make it faster if this is in the hot path.
There was a problem hiding this comment.
Yeah @gaunerst has a nice triton kernel for doing the cast + pack that I am going to try in a follow up PR
Stacked PRs:
Add NVFP4 Inference flow
Details:
I kept this separate for MX but realistically we should probably merge the two. Basic support for blocksize 16 + e4m3 scales.
Double Quant Update
Ignore previous comments, the double quant is actually really similar to NF4 where you just scale the fp32 scales prior to casting to e4m3 to try and reduce
scale quant error.I have that implemented now in the Nvfp4 code if a tesor_scale is given, just need to figure out how to thread to cublas param
scale_in_dor how we want to expose this. We currently don't expose the C matrix to the Python API so we could use alpha as @gau-nernst pointed out to me, however we dont expose alpha either 🙃. However if we wanted to use alpha we would need the value on the host, the sync would likely rule out this option. I might keep this double quant on hold until we have the public api, since I am thinking about addingscaleoverloads to addmm. However I read the cublas docs many times and it feels as though passing to scale result should work since we don't set the d_mode and its default value should work.Early Perf
No double quant here
which is even worse than mxfp4..., will profile later
Micro Bench
LLama 70B mlp no TP:
Diffusers
Errors
Annoyingly we are getting an error due to the
view as fp4x2+ packing https://fburl.com/cd92w431 because this is trying to be bitcast iside inside triton kernel which is very annoying. Not sure how this didn't show up until vllm / w/ mxfp4^ similar to this: triton-lang/triton#6054 but make the same changes in _inductor/utils.py as we did for float8em0
Numerics
Script: https://gist.github.com/drisspg/4024ed055a6db911495102614c674c4c -> still emulating till we fix this bug in cublaslt bindings

Double quant really helps w/ tensor that have very small amax values, likely by reducing the amount of underflows will verify:
Flow