Introduce new W8A8-FP-CSR quantization API #3258
namgyu-youn wants to merge 7 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3258
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 9 New Failures: as of commit f5f7a17 with merge base 3577306, the following jobs have failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@jcaip could you please check this PR?
cc @namgyu-youn Can you split this into two PRs? One for int8 and one for float8. In general I don't think we want to introduce weight-only sparsity configs for int8 and float8, because we don't have mixed-dtype kernel support currently. The only kernels we have are for int8 x int8 2:4 sparse and fp8 x fp8 2:4 sparse. I would like Int8SemiSparseTensor though, but I think it should live in prototype until we have a user for it. Also cc @bbeckca, who has been working on the fp8 x fp8 2:4 sparse tensor subclass migration in #3182.
@jcaip if we want to move int8 2:4 sparse to prototype, then we don't need to migrate the tensor, I think.
Okay, then I'll address only
cc @namgyu-youn I talked to @bbeckca and I think your PR is closer, so let's use it instead.
cc @jcaip to request review, thanks.
jcaip left a comment:
cc @namgyu-youn
I think there's a bit of confusion on what the tensor subclass should be storing and how to do the op overload.
Please take a look at https://github.com/pytorch/ao/pull/3182/files#diff-afc7dd21d2b704181a6fd55be989426c0217a2bbfb694af9eb9746239ec462ed for the appropriate logic / ops to be called.
```python
class Float8SemiSparseTensor(TorchAOBaseTensor):
    """
    W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout
    ...
    float8_dtype: float8 dtype variant
    """
```

nit: the comment looks wrong; CSR is compressed sparse row, and that's not the sparse format used here (2:4 sparsity).
```python
tensor_data_names = ["qdata", "qdata_compressed", "scale"]
```

I think `quantized_sparse_data` and `quantized_sparse_metadata` would be better variable names here: `quantized_sparse_data` holds the specified values and `quantized_sparse_metadata` holds the sparsity metadata.
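A sketch of the suggested storage layout, assuming the `TorchAOBaseTensor` conventions the PR already uses:

```python
from torchao.utils import TorchAOBaseTensor

class Float8SemiSparseTensor(TorchAOBaseTensor):
    # quantized_sparse_data: the specified (kept) fp8 values of the 2:4 weight
    # quantized_sparse_metadata: the metadata locating those values in each group
    # scale: the dequantization scale
    tensor_data_names = ["quantized_sparse_data", "quantized_sparse_metadata", "scale"]
```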
```python
    )

    @property
    def qdata_fp8(self):
```
```python
w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

# Check for all-zero (sparsity=1) tensor
if w_sparse.abs().max() == 0:
```

I think this should be supported, actually? I don't see why we should error here.
```python
with torch.no_grad():
    w_sparse = w.clone()

pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
```

you can use this util:
Line 101 in 315e9b4
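For context, the quoted lines implement magnitude-based 2:4 pruning; a self-contained equivalent is below (the referenced torchao util would replace this hand-rolled version):

```python
import torch

def prune_24_by_magnitude(w: torch.Tensor) -> torch.Tensor:
    # Zero out the two smallest-magnitude entries in each contiguous group
    # of four, which is exactly what the argsort/scatter_ lines above do.
    w_sparse = w.clone()
    pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
    w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)
    return w_sparse
```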
```python
# Store fp8 data in both dense and compressed formats
fp8_data_fp16 = fp8_data.to(torch.float16)

fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
```

We should use the torchao cutlass packing kernels here, not the default torch ones:
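A sketch of what that could look like, assuming the `torchao.ops.to_sparse_semi_structured_cutlass_sm9x_f8` packing op used by torchao's existing CUTLASS semi-sparse layout (it packs fp8 directly, so no fp16 round trip is needed); `prune_24_by_magnitude` is the helper sketched earlier:

```python
import torch
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8

w = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
fp8_data = prune_24_by_magnitude(w).to(torch.float8_e4m3fn)
# pack into specified values + 2:4 metadata with the torchao CUTLASS kernel,
# instead of calling to_sparse_semi_structured() on an fp16 copy
sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(fp8_data)
```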
```python
if not (scale > 0).all():
    raise ValueError(f"Scale contains non-positive values: min={scale.min()}")

scale_expanded = scale.unsqueeze(1)
```

Is this different from Float8Tensor? Can we use the same scale calculation logic as we use there?
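For comparison, the dynamic rowwise scale used for Float8Tensor-style quantization is essentially the per-row amax over the fp8 max; a sketch (function name illustrative):

```python
import torch

def rowwise_fp8_scale(w: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # per-row amax mapped onto the fp8 representable range; positive by
    # construction, so no explicit (scale > 0) validation is needed
    amax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    return amax.to(torch.float32) / torch.finfo(fp8_dtype).max
```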
```python
fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)

return cls(
    fp8_data,  # dense for dequantization
```

We shouldn't be storing both the dense data and the compressed data; we should be storing the sparse specified values and the sparse metadata.
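A sketch of a constructor following that advice, meant to slot into the class sketch above; it reuses the hypothetical helpers from earlier snippets (`prune_24_by_magnitude`, `rowwise_fp8_scale`) plus the assumed torchao packing op:

```python
import torch
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8

@classmethod
def from_hp(cls, w: torch.Tensor) -> "Float8SemiSparseTensor":
    scale = rowwise_fp8_scale(w)                       # (out_features, 1)
    w_24 = prune_24_by_magnitude(w)                    # 2:4 pruned, high precision
    fp8_data = (w_24 / scale).to(torch.float8_e4m3fn)  # quantize
    sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(fp8_data)
    # store only the specified values + metadata; no dense copy retained
    return cls(sparse, meta, scale.squeeze(-1))
```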
```python
    float8_dtype=float8_dtype,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
```

we should multiply by an identity matrix to dequantize, like we do here:
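The identity-multiply trick: feeding an identity matrix through the sparse fp8 linear kernel expands the packed weight back to dense, with the scale fused in. A sketch assuming the `torchao.ops.rowwise_scaled_linear_sparse_cutlass_f8f8` op; treat the exact signature and shapes here as illustrative:

```python
import torch
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

def dequantize_24_fp8(sparse, meta, scale, in_features, out_dtype=torch.bfloat16):
    # eye(in_features) @ W.T == W.T, so one sparse matmul against the identity
    # recovers the dense (pruned) weight; the kernel applies `scale` in its epilogue
    eye = torch.eye(in_features, device=sparse.device).to(torch.float8_e4m3fn)
    eye_scale = torch.ones(in_features, device=sparse.device)
    w_t = rowwise_scaled_linear_sparse_cutlass_f8f8(eye, eye_scale, sparse, meta, scale)
    return w_t.t().to(out_dtype)
```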
```python
x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)

# MatMul
x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
```

We should use the torchao cutlass fp8 kernels, which fuse in the scale multiplication here.
See
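A sketch of the fused path, again assuming the `rowwise_scaled_linear_sparse_cutlass_f8f8` op and the field names from the storage sketch above (all illustrative):

```python
import torch
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8

def fp8_24_linear(x, weight_tensor, bias=None):
    # dynamic rowwise fp8 quantization of the activation
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    x_scale = x_scale / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
    # one fused kernel: fp8 x fp8 2:4 sparse matmul with both scales applied,
    # no padding or fp16 detour through cuSPARSELt needed
    return rowwise_scaled_linear_sparse_cutlass_f8f8(
        x_fp8,
        x_scale,
        weight_tensor.quantized_sparse_data,
        weight_tensor.quantized_sparse_metadata,
        weight_tensor.scale,
        bias,
    )
```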
delete? we don't want this to be in prototype I think
this should be added to the init file without the prototype in the path
also need to add to Float8DynamicActivationFloat8WeightConfig?
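For context, the end-user flow this would plug into; whether the sparse subclass lands behind `Float8DynamicActivationFloat8WeightConfig` or a dedicated config is exactly the open question, so the wiring below is hypothetical:

```python
import torch
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().bfloat16()
# hypothetical: once Float8SemiSparseTensor is wired into this config, a call
# like this would swap eligible Linear weights for the fp8 2:4 sparse subclass
quantize_(model, Float8DynamicActivationFloat8WeightConfig())
```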
@jcaip Thanks a lot for the comprehensive review. I didn't know there was an already-opened PR (#3182), and I found my implementation is quite far from it (mostly the ops and kernels). Therefore, the right move seems to be reopening #3182 and letting me update it after the last review. Is it okay to go with this?
@namgyu-youn I think it'll be easier for me to just migrate this over; mind if I take over the PR? #3182 is also quite far from landing.
@pytorchbot label "sparsity"
Summary:
Introduce a new W8A8-FP-CSR quantization API, Float8SemiSparseTensor, which specializes in the 2:4 semi-sparse pattern using cuSPARSELt acceleration (https://docs.nvidia.com/cuda/cusparselt/).

Related Issue/PR: #2752
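A minimal round-trip sketch of the two core operations this PR introduces (the import path and constructor name are assumptions based on torchao's tensor-subclass conventions, not the final API):

```python
import torch
# hypothetical import path; depends on where the subclass finally lands
from torchao.prototype.quantization import Float8SemiSparseTensor

w = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
w_q = Float8SemiSparseTensor.from_hp(w)   # prune to 2:4, quantize to fp8, pack
w_dq = w_q.dequantize(torch.bfloat16)     # expand back to a dense bf16 tensor
```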
Future Plan:
This PR only introduces core operations (quantization/dequantization). For better API support, we have to introduce tensor utility operations like indexing and slicing.
Test Plan:
test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py