
[sparse] Migrate Float8SemiSparseTensor off of AQT #3361

Merged
jcaip merged 25 commits into main from jcaip/fp8-semi-sparse-migration
Dec 22, 2025

Conversation

@jcaip
Contributor

@jcaip jcaip commented Nov 20, 2025

This PR migrates Float8DynamicActivationFloat8SemiSparseWeightConfig off of the AQT CutlassSemiSparseLayout subclass.

The old AQT flow can still be used by passing version=1 into the config.

Testing:

pytest test/quantization/quantize_/workflows/float8/test_float8_semi_sparse_tensor.py
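For reviewers trying this out, a minimal usage sketch, assuming the config is importable from torchao.quantization.quant_api (the exact import path is an assumption):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
)

model = torch.nn.Linear(1024, 1024, bias=False).cuda().to(torch.bfloat16)

# New flow (default): the weight is swapped for the new Float8SemiSparseTensor subclass.
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

# Old AQT flow, kept for backwards compatibility:
# quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig(version=1))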

@pytorch-bot

pytorch-bot Bot commented Nov 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3361

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2c7f730 with merge base 7035fb7:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 20, 2025
@jcaip jcaip added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Nov 20, 2025
@jcaip jcaip requested a review from jerryzh168 November 20, 2025 18:21
@meta-codesync

meta-codesync Bot commented Nov 20, 2025

@jcaip has imported this pull request. If you are a Meta employee, you can view this in D87560869.


"""Use torchao cutlass kernel for fp8 + 2:4 sparse mm, requires building torchao with CUDA
"""
SPARSE_CUTLASS = "sparse_cutlass"
Contributor

my understanding is this is a new packing format, why is this a new kernel preference?

Contributor Author

sparse_cutlass vs sparse_cusparselt/hipsparselt is something we will need for AMD support coming up next half, which sounds like a kernel preference to me (deciding which op to use).

But if this is more of a general thing and packing_format is the more specific way to decide op dispatch, I am fine with using that as well.

Contributor

@jcaip, it would be good to specify whether the data format will be different and the kernels different, or whether the data format is the same and only the kernels differ.
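To make the distinction concrete, a hypothetical sketch of the two axes (illustrative names, not the merged code): a packing format names the byte layout stored on the tensor, while a kernel preference picks between ops that consume the same bytes.

from enum import Enum

class PackingFormat(str, Enum):
    # different bytes: fp8 specified values + a separate metadata tensor
    SPARSE_CUTLASS = "sparse_cutlass"
    # different bytes: one packed tensor with the metadata appended
    SPARSE_CUSPARSELT = "sparse_cusparselt"

class KernelPreference(str, Enum):
    # same stored bytes, different op chosen at dispatch time
    AUTO = "auto"
    FBGEMM = "fbgemm"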

Comment on lines +172 to +178
    kernel_choice = "sparse_cutlass"
elif kernel_preference == KernelPreference.SPARSE_CUTLASS:
    # if the user explicitly chose the SPARSE_CUTLASS kernel preference, use the cutlass sparse kernel
    assert is_sm_at_least_90(), (
        "Specified sparse_cutlass kernel and hardware is not >= SM 9.0 (>= H100)"
    )
    kernel_choice = "sparse_cutlass"
Contributor

if "sparse_cutlass" is the only option, then I don't think we are dealing with a kernel preference here?

from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Float8SemiSparseTensor(TorchAOBaseTensor):
Contributor

is there a more descriptive name, something like Float8With2By4SparsityTensor?

    dtype: Optional[torch.dtype] = None,
):
    super().__init__()
    self.sparse_quantized_data = sparse_quantized_data
Contributor

how about qdata, to match the other tensors?

Contributor Author

We can do sparse_qdata? But I think just qdata is a bit confusing, since the quantized data is split between the specified values and the metadata.
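For background on that split, a rough sketch using PyTorch's public semi-structured sparse API (fp16 here purely for illustration): a 2:4-sparse matrix compresses into a half-size tensor of "specified" values plus a small metadata tensor recording which two of every four positions were kept.

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # use the CUTLASS backend

w = torch.randn(128, 128, device="cuda", dtype=torch.float16)
# zero two of every four elements along each row to satisfy the 2:4 pattern
w = w * torch.tensor([1, 1, 0, 0], device="cuda").tile(128, 32)
w_sparse = to_sparse_semi_structured(w)
# w_sparse holds the kept values and the position metadata as separate buffers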

"""
Sparse packing formats for 2:4 sparsity + FP8 quantization
"""
SPARSE_CUTLASS = "sparse_cutlass"
Contributor

The intent is for the sparse tensor to use OPAQUE, and you can keep these formats internal to your workflow

Comment on lines +57 to +63
SPARSE_CUTLASS = "sparse_cutlass"

"""
SPARSE_CUSPARSELT will pack the quantized data into a single tensor, sparse_qdata, which contains the specified values with the metadata appended.
This packing format will dispatch to `_cslt_sparse_mm`, which does not fuse per-row scaling into the matmul.
"""
SPARSE_CUSPARSELT = "sparse_cusparselt"
Contributor

should these belong to Float8PackingFormat? we structure these by "dtype" currently

Contributor Author

I think Float8PackingFormat was removed recently, so we can't reuse.

Contributor

it's fine to add again for this I think
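For readers following the thread, a hypothetical sketch of what the unfused scaling implies for the cuSPARSELt path (sparse_mm stands in for _cslt_sparse_mm; all names are illustrative):

def rowwise_scaled_sparse_mm_unfused(a_fp8, w_packed, scale_a, scale_w, sparse_mm):
    # sparse_mm computes only the raw fp8 x fp8 sparse matmul
    out = sparse_mm(a_fp8, w_packed)  # [M, N], no scaling applied
    # the per-row activation and weight scales need a second elementwise pass
    return out * scale_a.view(-1, 1) * scale_w.view(1, -1)

The CUTLASS kernel (rowwise_scaled_linear_sparse_cutlass_f8f8) presumably takes the scales as arguments and applies them inside the matmul, avoiding that extra pass.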

Contributor Author

@jcaip jcaip left a comment

cc @bbeckca. I left some comments, but this should be pretty close to landing.

I think Randy has some scripts for running the ads workflows; we should test the default version bump on those before we land.

One quick heads-up on running this locally: you'll need to build the custom kernels with USE_CPP=1 pip install -e . --no-build-isolation, otherwise you won't have op support.


@jcaip jcaip requested a review from vkuzo December 18, 2025 05:33
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Sparse2x4Float8Tensor(TorchAOBaseTensor):
Contributor

the way we structure these is actually a one-to-one correspondence between packing formats and tensors, for example:

if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
    assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
        f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
        f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
    )
if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
    new_weight = Int4PreshuffledTensor.from_hp(
        weight,
        block_size,
        activation_dtype=torch.bfloat16,
    )
    return new_weight
elif int4_packing_format == Int4PackingFormat.PLAIN:
    new_weight = Int4Tensor.from_hp(
        weight,
        block_size,
    )
    return new_weight
elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
    new_weight = Int4PlainInt32Tensor.from_hp(
        weight,
        block_size,
    )
    return new_weight
elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
    new_weight = Int4MarlinSparseTensor.from_hp(
        weight,
        block_size,
    )
    return new_weight
elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
    new_weight = Int4TilePackedTo4dTensor.from_hp(
        weight,
        block_size,
        int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
    )
    return new_weight
else:
    raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")

Contributor Author

OK, I'll just make this the CUTLASS format and open a new PR for cuSPARSELt.

@jerryzh168
Contributor

I think we should split the tensor into 2, one for each packing format
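Following the int4 pattern quoted above, a hypothetical sketch of what that split might look like on the float8 side (tensor and enum names are assumptions, not the merged code):

if packing_format == Float8PackingFormat.SPARSE_CUTLASS:
    # CUTLASS layout: fp8 specified values + a separate metadata tensor
    return Float8SemiSparseTensor.from_hp(weight)
elif packing_format == Float8PackingFormat.SPARSE_CUSPARSELT:
    # cuSPARSELt layout: one packed tensor; left to a follow-up PR
    return Float8SemiSparseCusparseltTensor.from_hp(weight)
else:
    raise ValueError(f"Unsupported float8 packing format: {packing_format}")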

@jcaip jcaip requested a review from jerryzh168 December 19, 2025 19:27
return out


@implements(aten.clone.default)
Contributor

this should be supported already by TorchAOBaseTensor I think?
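For context, TorchAOBaseTensor can derive ops like clone generically for subclasses that declare their components via the tensor_data_names / tensor_attribute_names convention; a rough sketch of the pattern (simplified, not the actual implementation):

def _generic_clone(tensor):
    # rebuild the subclass from cloned component tensors plus the
    # non-tensor attributes, so no per-subclass override is needed
    cloned = [getattr(tensor, name).clone() for name in tensor.tensor_data_names]
    attrs = [getattr(tensor, name) for name in tensor.tensor_attribute_names]
    return tensor.__class__(*cloned, *attrs)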

args[1],
args[2] if len(args) > 2 else None,
)
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
Contributor

nit: this is already imported at the top of the file

@jerryzh168
Contributor

The main thing is to move the config into Float8DynamicActivationFloat8WeightConfig, I think; the others are mostly nits.

"Int8Tensor",
"QuantizeTensorToInt8Kwargs",
"Float8Tensor",
"Sparse2x4Float8TensorCUTLASS",
Contributor

nit: Sparse2x4CUTLASSFloat8Tensor

if packing_format == Float8PackingFormat.PLAIN and isinstance(
    weight_granularity, PerRow
):
    assert weight.dtype == torch.bfloat16, (
Contributor

probably better to move and duplicate this code into the config.version == 1 and the config.version == 2 + packing_format == Float8PackingFormat.PLAIN branches for now (the current modification will skip this assertion for v1)

Contributor Author

looks like v1 config just got deleted, so this should be simpler now

model = torch.compile(model)
sparse_result = model(input)

torch.testing.assert_close(
Contributor

nit: I think we can use SQNR, and also compare the PLAIN and SPARSE_CUTLASS formats
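A sketch of the suggested test shape, where model_plain and model_sparse are hypothetical copies of the model quantized with each packing format, compute_error is the SQNR helper used later in this thread, and 20 dB is a threshold picked purely for illustration:

from torchao.quantization.utils import compute_error

baseline_result = model_bf16(input)   # unquantized reference
plain_result = model_plain(input)     # PLAIN packing format
sparse_result = model_sparse(input)   # SPARSE_CUTLASS packing format

# both quantized paths should stay close to the bf16 baseline
self.assertGreater(compute_error(baseline_result, plain_result), 20)
self.assertGreater(compute_error(baseline_result, sparse_result), 20)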

cloned = model.weight.clone().dequantize()

for o, c in zip(original, cloned):
    torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)
Contributor

nit: assertEqual

Contributor

@jerryzh168 jerryzh168 left a comment

thanks, LGTM

sparse_result = model(input)
sparse_sqnr = compute_error(baseline_result, sparse_result)

self.assertEqual(dense_sqnr, sparse_sqnr)
Contributor

would these be the same?

I meant we should just compare dense_result and sparse_result, and the SQNR should be high.

Contributor Author

@jcaip jcaip Dec 20, 2025

dense_result and sparse_result should be numerically identical because we mask the weights ahead of time; the differences are just because of compile.

I think that's a better check than SQNR between the dense unmasked and sparse results.
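A sketch of that stricter comparison (make_2_4_mask and sparse_linear are hypothetical stand-ins): since the weight is pruned to the 2:4 pattern before quantization, the dense and sparse paths see identical weights and the outputs can be compared directly.

w_masked = w * make_2_4_mask(w)   # prune ahead of time (hypothetical helper)
dense_out = x @ w_masked.t()      # dense reference over the masked weight
sparse_out = sparse_linear(x)     # sparse kernel over the same masked weight
# identical math up to kernel/compile reordering, so tight tolerances hold
torch.testing.assert_close(dense_out, sparse_out, rtol=1e-2, atol=1e-2)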

Contributor

@jerryzh168 jerryzh168 Dec 20, 2025

I see, then should we compare them directly?

I'm unsure what "making sure the SQNRs are equivalent" means; is it less strict than equality? Something similar to assert allclose?

Contributor Author

Yeah, I will change this to the old test. Checking that the SQNR is equivalent is less strict than that.

Contributor

nit: I think the file name should be aligned with the tensor name

from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Sparse2x4CUTLASSFloat8Tensor(TorchAOBaseTensor):
Contributor

nit: same here, file name should be aligned with the tensor name Sparse2x4CUTLASSFloat8Tensor

)


class TestSparse2x4Float8Tensor(common_utils.TestCase):
Contributor

also the test name

@jcaip jcaip merged commit 486fe0d into main Dec 22, 2025
20 of 23 checks passed