
Add _scaled_mm_v2 API#164141

Closed
slayton58 wants to merge 10 commits into gh/slayton58/16/base from gh/slayton58/16/head

Conversation

@slayton58
Contributor

@slayton58 slayton58 commented Sep 29, 2025

Stack from ghstack (oldest at bottom):

Summary:

  • Add new scaled-MM API to future-proof and clean up existing code.
  • Scaling is explicitly described rather than inferred
  • Swizzling of scales must now be defined (vs. inferred)
  • Adds API support for multi-level scaling
  • Refactor dispatch logic to make it easier to add new implementations

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Sep 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164141

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 49eb776 with merge base 3288fbf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@slayton58
Contributor Author

@pytorchbot label "release notes: quantization"

@pytorch-bot pytorch-bot bot added the release notes: quantization label Sep 29, 2025
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

[ghstack-poisoned]
@pytorchmergebot
Collaborator

Rebased gh/slayton58/17/orig onto refs/remotes/origin/viable/strict because #164142 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/164141)

bool use_fast_accum,
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape M/N for A/B
Collaborator

Some of these should be TORCH_CHECK_VALUE or raise NotImplementedError

Contributor Author

What's the guidance for TORCH_CHECK vs. TORCH_CHECK_VALUE vs. raising NotImplementedError? From a quick look, only TORCH_CHECK is used in this particular file to this point

Collaborator

@slayton58 TORCH_CHECK raises RuntimeError, TORCH_CHECK_VALUE raises ValueError; I don't know if we have a pre-existing macro for raising NotImplementedError. RuntimeErrors are under-specified: it could be a transient error, a user error, etc. ValueError denotes a user error, i.e. the args to the specified function are invalid. NotImplementedError means something could be supported in upstream PyTorch, but we haven't implemented it yet.

Most of this file should probably be changed to ValueError at some point, but that can be a subsequent PR
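For readers outside the codebase, the macros map onto Python exception types roughly as described above. Here is a minimal standalone sketch of that mapping — these stand-ins are not the real TORCH_CHECK* macros (which live in c10/util/Exception.h), just an illustration of the error-type distinction:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Sketch: TORCH_CHECK surfaces in Python as RuntimeError,
// TORCH_CHECK_VALUE as ValueError. These macros only mimic that mapping.
#define CHECK_RUNTIME(cond, msg) \
  do { if (!(cond)) throw std::runtime_error(msg); } while (0)
#define CHECK_VALUE(cond, msg) \
  do { if (!(cond)) throw std::invalid_argument(msg); } while (0)

// Example validation in the style discussed above: a bad user argument
// should surface as a ValueError, not a generic RuntimeError.
inline void check_scale_dtype_is_fp32(bool is_fp32) {
  CHECK_VALUE(is_fp32, "scale tensors must be fp32");
}
```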

[ghstack-poisoned]
@danielvegamyhre
Copy link
Contributor

danielvegamyhre commented Oct 2, 2025

@slayton58 I like the explicit instead of inferred scaling type.

Question: why do the scaling recipe strings only repeat for fp8 types but not for MX types? e.g.,

  • ScaledGemmImplementation::TENSORWISE_TENSORWISE => "tensorwise_tensorwise"
  • ScaledGemmImplementation::MXFP8_MXFP8 => "mxfp8"

Seems like we should stick with one pattern for consistency. Is having it repeat twice (presumably once for each operand) meant to future-proof against recipes which may have different scaling types for each operand?

- func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_mm_cuda_v2
Contributor

We should think about how this will look for other backends, e.g. do we ever think that CPU will support some subset of recipes? I don't think that changes anything, just wanted to note it.

Contributor Author

CPU can/will certainly support some subset (or superset) of recipes - I'm honestly not sure what is / will be supported though. A CPU backend wouldn't be hard to add (or a shim to dispatch to the existing _scaled_mm_cpu backend).

* strictly-typed enum
*/
template <class EnumType>
std::vector<EnumType> convert_int_to_enum(ArrayRef<long>& v) {
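For context, the quoted helper (truncated above) might look like the following standalone sketch, with std::vector<long> standing in for ArrayRef<long> and a hypothetical enum with made-up values:

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical strictly-typed enum; the real code converts to the
// scaling/swizzle enums passed through native_functions as int[].
enum class SwizzleType : int64_t { NO_SWIZZLE = 0, SWIZZLE_A = 1, SWIZZLE_B = 2 };

// Sketch of convert_int_to_enum: cast each integer to the enum,
// rejecting values outside the known range rather than trusting callers.
template <class EnumType>
std::vector<EnumType> convert_int_to_enum(const std::vector<long>& v,
                                          long max_valid) {
  std::vector<EnumType> out;
  out.reserve(v.size());
  for (long i : v) {
    if (i < 0 || i > max_valid) {
      throw std::invalid_argument("integer out of range for enum");
    }
    out.push_back(static_cast<EnumType>(i));
  }
  return out;
}
```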
Contributor

Might be worth pybinding the type, I think I did this w/ SDPBackend to make sure things stay consistent

Contributor Author

I am 100% willing to do this - I don't love the current approach - an example would be amazing. I think I've conflated "no enums through native_functions" with "no enums shared at all between Python/C++ in PyTorch"..

Contributor

Spoke offline but for provenance:

py::enum_<sdp::SDPBackend>(

bool found_impl = false;
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;

for (const auto& fn_entry : scale_kernel_dispatch) {
Contributor

Should it be a scan down, finding the first match? Or should it just be a direct map from recipe pair to impl?

Contributor Author

The big problem I have with a map -> impl is that different implementations can (and imo really should) have different signatures - nvfp4 x nvfp4 needs global scales and swizzles passed in addition, for instance. I think this might be work-around-able with std::bind to present a unified API, but that's also messy..

Contributor

Yeah, that makes sense, but to confirm: every accept_fn matches against the enum, right? So there is no ambiguity?
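For provenance, the first-match scan being discussed can be sketched like this — recipe matching is simplified to a string key here, while the real table pairs each entry with an acceptance function inspecting dtypes, scales, and scaling types:

```cpp
#include <array>
#include <cassert>
#include <functional>
#include <string>
#include <tuple>

enum class ScaledGemmImplementation { NONE, TENSORWISE_TENSORWISE, MXFP8_MXFP8 };

// Acceptance functions return true if this implementation can handle
// the requested recipe (simplified to a string in this sketch).
using accept_fn = std::function<bool(const std::string&)>;

// First-match-wins linear scan over the dispatch table.
inline ScaledGemmImplementation dispatch(const std::string& recipe) {
  static const std::array<
      std::tuple<std::string, accept_fn, ScaledGemmImplementation>, 2>
      table = {{
          {"tensorwise_tensorwise",
           [](const std::string& r) { return r == "tensorwise_tensorwise"; },
           ScaledGemmImplementation::TENSORWISE_TENSORWISE},
          {"mxfp8",
           [](const std::string& r) { return r == "mxfp8"; },
           ScaledGemmImplementation::MXFP8_MXFP8},
      }};
  for (const auto& [name, accepts, impl] : table) {
    if (accepts(recipe)) return impl;  // first matching entry wins
  }
  return ScaledGemmImplementation::NONE;
}
```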

@slayton58
Contributor Author

@slayton58 I like the explicit instead of inferred scaling type.

Question: why do the scaling recipe strings only repeat for fp8 types but not for MX types? e.g.,

  • ScaledGemmImplementation::TENSORWISE_TENSORWISE => "tensorwise_tensorwise"
  • ScaledGemmImplementation::MXFP8_MXFP8 => "mxfp8"

Seems like we should stick with one pattern for consistency. Is having it repeat twice (presumably once for each operand) meant to future-proof against recipes which may have different scaling types for each operand?

@danielvegamyhre
I see what you mean here - I was writing in the sense that "mxfp8" is (to me at least) a clearly-defined combination of input types - calling "mxfp8" gives all the information one needs for both arguments, vs. having to explicitly repeat for both inputs. However, the intention was to provide an interface that was extensible enough for potential future combinations (mxfp8 activations x mxfp4 weights is one that I've seen put forward for instance), and in that sense, repeating for all inputs, so

ScaledGemmImplementation::MXFP8_MXFP8 => "mxfp8_mxfp8"

isn't a bad idea for consistency.

[ghstack-poisoned]
* Both inputs must be fp8
* A, B must only have 1 scale each; A: Blockwise_1x128 (float), B: Blockwise_128x128 (float)
*/
bool check_deepseek_recipe(ScalingType expected_recipe_a,
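A standalone sketch of what such an acceptance check does — enum values and the 128x-style name are illustrative, and the real check also validates dtypes and tensor shapes:

```cpp
#include <cassert>
#include <vector>

// Illustrative scaling-type enum for this sketch.
enum class ScalingType { TensorWise, RowWise, BlockWise1x128, BlockWise128x128 };

// Sketch of the recipe check under discussion: exactly one scale per
// operand, A scaled 1x128 and B scaled 128x128 (fp32 dtype checks elided).
inline bool check_1x128_128x128_recipe(const std::vector<ScalingType>& recipe_a,
                                       const std::vector<ScalingType>& recipe_b) {
  return recipe_a.size() == 1 && recipe_b.size() == 1 &&
         recipe_a[0] == ScalingType::BlockWise1x128 &&
         recipe_b[0] == ScalingType::BlockWise128x128;
}
```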
Contributor

nit: let's change the name

Contributor Author

From deepseek? What would you prefer?

Contributor

maybe just name it after the 128x recipes, similar to the other ones

using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;

std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
Contributor

I think I wrote a utility called

return {{std::forward<T>(t)...}};

which helps w/ seg faults as we expand this list - don't ask me how I know 😂
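The utility being referred to is presumably a size-deducing make_array-style helper; a sketch of the idea (name hypothetical) — the array length is deduced from the argument count, so the table can grow without a hard-coded size constant drifting out of sync with the entry list:

```cpp
#include <array>
#include <cassert>
#include <utility>

// Hypothetical make_array: deduce the std::array length from the number
// of arguments, so adding an entry can never leave trailing
// value-initialized slots (e.g. empty std::function entries).
template <class T, class... Args>
constexpr std::array<T, sizeof...(Args)> make_array(Args&&... args) {
  return {{std::forward<Args>(args)...}};
}
```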

{ "mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};

Tensor&
_cutlass_scaled_gemm(
Contributor

Nit: the name is weird since it goes to a lot more than just cutlass

#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
Contributor

I would love the AMD path to be a little more modular, e.g. one #ifdef: USE_ROCM -> go to the ROCm path, otherwise go to the CUDA path. It makes it a lot easier to grok.
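The suggested structure - a single top-level #ifdef selecting a backend path - might look like this (function names hypothetical; the real entry points would wrap the tunable-op and cuBLAS/hipBLASLt dispatch respectively):

```cpp
#include <cassert>
#include <string>

// Hypothetical backend entry points, each readable as a straight line.
inline std::string scaled_gemm_rocm() { return "rocm"; }
inline std::string scaled_gemm_cuda() { return "cuda"; }

// One top-level switch instead of interleaved #ifdef blocks.
inline std::string scaled_gemm() {
#ifdef USE_ROCM
  return scaled_gemm_rocm();
#else
  return scaled_gemm_cuda();
#endif
}
```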

auto scaling_choice_b = ScalingType::RowWise;
//
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
Contributor

Is rowwise also not supported on sm100 like the 128x recipes?

Contributor Author

9.0+ = >= 9.0, so yes, it should be on Blackwell too (I think my CC nomenclature here is different to yours :D )

Contributor

Ohh, I just meant: didn't the 128x recipes ONLY work on sm90 and not Blackwell? Not sure if cuBLAS did the same thing for the rowwise recipes

Contributor

@drisspg drisspg left a comment

Comment soup but LGTM

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #164142

[ghstack-poisoned]
