[sparse] Migrate Float8SemiSparseTensor off of AQT #3361
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3361
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 2c7f730 with merge base 7035fb7:
BROKEN TRUNK - The following job failed but was also present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| """Use torchao cutlass kernel for fp8 + 2:4 sparse mm, requires building torchao with CUDA | ||
| """ | ||
| SPARSE_CUTLASS = "sparse_cutlass" |
my understanding is this is a new packing format, why is this a new kernel preference?
sparse_cutlass vs sparse_cusparselt/hipsparselt is something we will need for AMD support coming up next half, which sounds like kernel preference to me (decide which op to use).
But if this is more of a general thing and packing_format is the more specific way to decide op dispatch, I am fine with using that as well.
@jcaip, it would be good to specify whether the data format will be different along with the kernels, or whether the data format is the same and only the kernels differ.
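For context, a minimal sketch of what kernel-preference-based dispatch could look like; the SPARSE_CUTLASS/SPARSE_CUSPARSELT values come from this PR, while AUTO and the helper function are hypothetical:

```python
from enum import Enum

class KernelPreference(str, Enum):
    AUTO = "auto"
    SPARSE_CUTLASS = "sparse_cutlass"        # torchao CUTLASS fp8 + 2:4 kernel
    SPARSE_CUSPARSELT = "sparse_cusparselt"  # cuSPARSELt (or hipSPARSELt on AMD)

def pick_sparse_kernel(pref: KernelPreference) -> str:
    # hypothetical helper: choose which mm op to use based on the preference
    if pref == KernelPreference.SPARSE_CUSPARSELT:
        return "_cslt_sparse_mm"
    # SPARSE_CUTLASS, and AUTO as the default
    return "rowwise_scaled_linear_sparse_cutlass_f8f8"
```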
```python
        kernel_choice = "sparse_cutlass"
    elif kernel_preference == KernelPreference.SPARSE_CUTLASS:
        # if the user explicitly chose the SPARSE_CUTLASS kernel preference,
        # make sure the hardware supports it
        assert is_sm_at_least_90(), (
            "Specified sparse_cutlass kernel and hardware is not >= SM 9.0 (>= H100)"
        )
        kernel_choice = "sparse_cutlass"
```
if "sparse_cutlass" is the only option, then I don't think we are dealing with a kernel preference here?
```python
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Float8SemiSparseTensor(TorchAOBaseTensor):
```
is there a more descriptive name, something like Float8With2By4SparsityTensor?
```python
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.sparse_quantized_data = sparse_quantized_data
```
how about qdata to match other tensors
We can do sparse_qdata? But I think just qdata is a bit confusing, since the qdata is split between the specified values and the metadata.
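For reference, a rough sketch of why the packed buffer mixes values and metadata, using PyTorch's private torch._cslt_compress op (requires a cuSPARSELt-enabled CUDA build; shapes are illustrative):

```python
import torch

# build an illustrative 2:4-sparse fp16 weight, then cast to fp8
w = torch.randn(128, 128, device="cuda", dtype=torch.float16)
mask = torch.tensor([1, 1, 0, 0], device="cuda", dtype=torch.bool).tile(128, 32)
w_fp8 = (w * mask).to(torch.float8_e4m3fn)

# torch._cslt_compress (private PyTorch op) returns one flat buffer holding
# the specified (nonzero) values followed by the 2:4 sparsity metadata,
# which is why plain "qdata" would be a slightly misleading name here
sparse_qdata = torch._cslt_compress(w_fp8)
```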
| """ | ||
| Sparse packing formats for 2:4 sparsity + FP8 quantization | ||
| """ | ||
| SPARSE_CUTLASS = "sparse_cutlass" |
The intent is for the sparse tensor to use OPAQUE, and you can keep these formats internal to your workflow
```python
SPARSE_CUTLASS = "sparse_cutlass"

"""
SPARSE_CUSPARSELT will pack the quantized data into a single tensor, sparse_qdata, which contains the specified values with the sparsity metadata appended.
This packing format will dispatch to `_cslt_sparse_mm`, which does not fuse per-row scaling into the matmul.
"""
SPARSE_CUSPARSELT = "sparse_cusparselt"
```
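A rough sketch of the unfused scaling path this docstring describes, using PyTorch's private cuSPARSELt ops (shapes, scales, and dtype choices are illustrative; exact layout requirements may differ):

```python
import torch

# illustrative 2:4-sparse fp8 weight (128x256) and fp8 activation (64x256)
w = torch.randn(128, 256, device="cuda", dtype=torch.float16)
mask = torch.tensor([1, 1, 0, 0], device="cuda", dtype=torch.bool).tile(128, 64)
w_fp8 = (w * mask).to(torch.float8_e4m3fn)
x_fp8 = torch.randn(64, 256, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)

w_packed = torch._cslt_compress(w_fp8)  # sparse_qdata: values + metadata

# _cslt_sparse_mm computes (sparse W) @ x.t() without applying per-row
# scales, so the scaling has to run as a separate elementwise step
y = torch._cslt_sparse_mm(w_packed, x_fp8.t(), out_dtype=torch.bfloat16)

w_scale = torch.rand(128, 1, device="cuda")  # illustrative per-row weight scale
y = y * w_scale  # per-row scaling applied after the matmul, not fused into it
```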
should these belong to Float8PackingFormat? we structure these by "dtype" currently
I think Float8PackingFormat was removed recently, so we can't reuse.
it's fine to add it again for this, I think
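If it were re-added, a minimal sketch of what that enum might look like, based on the formats discussed in this PR (PLAIN is assumed from the existing fp8 flow):

```python
from enum import Enum

class Float8PackingFormat(str, Enum):
    # plain, non-sparse fp8 packing
    PLAIN = "plain"
    # 2:4 sparse packing for the torchao CUTLASS kernel
    SPARSE_CUTLASS = "sparse_cutlass"
    # 2:4 sparse packing for cuSPARSELt's _cslt_sparse_mm
    SPARSE_CUSPARSELT = "sparse_cusparselt"
```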
cc @bbeckca left some comments, but this should be pretty close to landing.
I think Randy has some scripts for running the ads workflows; we should test the default version bump on those before we land.
One quick heads up on running this locally - you'll need to build the custom kernel with `USE_CPP=1 pip install -e . --no-build-isolation`, otherwise you won't have op support.
```python
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Sparse2x4Float8Tensor(TorchAOBaseTensor):
```
the way we structure these is a one-to-one correspondence between packing formats and tensors actually, for example:
torchao/quantization/quant_api.py, lines 895 to 934 at 8806b02
ok, I'll just make this the CUTLASS format and open a new PR for cuSPARSELt.

I think we should split the tensor into 2, one for each packing format.
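A loose sketch of that one-to-one structure (Sparse2x4CUSPARSELTFloat8Tensor and the dispatch table are hypothetical, per the follow-up PR mentioned above; the classes are stubs here so the sketch stands alone):

```python
# placeholder classes standing in for the real tensor subclasses
class Sparse2x4CUTLASSFloat8Tensor: ...
class Sparse2x4CUSPARSELTFloat8Tensor: ...  # hypothetical follow-up PR

# one tensor subclass per packing format (the table itself is illustrative)
SPARSE_TENSOR_FOR_PACKING_FORMAT = {
    "sparse_cutlass": Sparse2x4CUTLASSFloat8Tensor,
    "sparse_cusparselt": Sparse2x4CUSPARSELTFloat8Tensor,
}

def tensor_cls_for(packing_format: str):
    # route each packing format to exactly one tensor subclass
    return SPARSE_TENSOR_FOR_PACKING_FORMAT[packing_format]
```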
```python
    return out


@implements(aten.clone.default)
```
this should be supported already by TorchAOBaseTensor I think?
```python
        args[1],
        args[2] if len(args) > 2 else None,
    )
    from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
```
nit: this is already imported at the top of the file
main thing is to move the config to the Float8DynamicActivationFloat8WeightConfig I think; others are mostly nits
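A rough sketch of what folding this into the existing config could look like (field names and placement are assumptions based on this comment, not the actual torchao config definition):

```python
from dataclasses import dataclass

@dataclass
class Float8DynamicActivationFloat8WeightConfig:
    # existing fields elided; this is a sketch of the suggestion above,
    # not the real torchao config
    version: int = 2
    # assumption: 2:4 sparsity selected via the packing format rather
    # than via a separate semi-sparse config
    packing_format: str = "plain"  # "plain" | "sparse_cutlass" | "sparse_cusparselt"
```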
| "Int8Tensor", | ||
| "QuantizeTensorToInt8Kwargs", | ||
| "Float8Tensor", | ||
| "Sparse2x4Float8TensorCUTLASS", |
nit: Sparse2x4CUTLASSFloat8Tensor
```python
    if packing_format == Float8PackingFormat.PLAIN and isinstance(
        weight_granularity, PerRow
    ):
        assert weight.dtype == torch.bfloat16, (
```
probably better to move and duplicate this code to the config.version == 1 and config.version == 2 + packing_format == Float8PackingFormat.PLAIN branches for now (the current modification will skip this assertion for v1)
looks like v1 config just got deleted, so this should be simpler now
```python
        model = torch.compile(model)
        sparse_result = model(input)

        torch.testing.assert_close(
```
nit: I think we can use SQNR, and also compare the PLAIN and SPARSE_CUTLASS formats
```python
        cloned = model.weight.clone().dequantize()

        for o, c in zip(original, cloned):
            torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
```
```python
        sparse_result = model(input)
        sparse_sqnr = compute_error(baseline_result, sparse_result)

        self.assertEqual(dense_sqnr, sparse_sqnr)
```
would these be the same?
I meant that we just compare dense_result and sparse_result, and the SQNR should be high.
dense_result and sparse_result should be numerically identical because we mask the weights ahead of time; the differences are just from compile.
I think that's a better check than SQNR between the dense unmasked and sparse results.
I see, then should we compare them directly?
I'm unsure what "making sure the SQNRs are equivalent" means - is it less strict than equal? Something similar to assert allclose?
yeah, I will change this to the old test. Checking that the SQNR is equivalent is less strict than that.
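For reference, a small sketch contrasting the two checks discussed here, assuming torchao's compute_error helper (which returns SQNR in dB) and that dense_result/sparse_result come from the same pre-masked weights; the 40 dB threshold is illustrative:

```python
import torch
from torchao.quantization.utils import compute_error  # SQNR in dB

def check_strict(dense_result, sparse_result):
    # stricter: direct elementwise comparison (what the old test did)
    torch.testing.assert_close(dense_result, sparse_result)

def check_sqnr(dense_result, sparse_result, threshold_db=40.0):
    # looser: only require the signal-to-quantization-noise ratio to be high
    sqnr = compute_error(dense_result, sparse_result)
    assert sqnr > threshold_db, f"SQNR too low: {sqnr:.1f} dB"
```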
nit: I think the file name should be aligned with the tensor name
```python
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Sparse2x4CUTLASSFloat8Tensor(TorchAOBaseTensor):
```
nit: same here, file name should be aligned with the tensor name Sparse2x4CUTLASSFloat8Tensor
```python
        )


class TestSparse2x4Float8Tensor(common_utils.TestCase):
```
This PR migrates Float8DynamicActivationFloat8SemiSparseWeightConfig off of using the AQT CutlassSemiSparseLayout subclass.
The old AQT flow can still be used by passing version=1 into the config.
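A minimal usage sketch of the migration described above (the Linear setup is illustrative, and the import path for the config is assumed from torchao.quantization):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig

# illustrative model; fp8 + 2:4 sparse kernels need a CUDA build of torchao
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

# new flow from this PR: subclass-based Float8 semi-sparse weight tensor
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

# old AQT flow, still reachable by pinning the config version:
# quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig(version=1))
```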
Testing: