Add scaled_mm python API, test #164142

Closed
slayton58 wants to merge 30 commits into gh/slayton58/17/base from gh/slayton58/17/head

@slayton58 (Contributor) commented Sep 29, 2025

Stack from ghstack (oldest at bottom):

Summary:

  • Add torch.nn.functional.scaled_mm as an abstraction around the C++
    methods
  • Wraps torch._scaled_mm_v2 API by default, but user can force use of
    the older torch._scaled_mm interface.
  • Scaled MM tests now run on the new API

Test Plan:

pytest test/test_scaled_matmul_cuda.py

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>

pytorch-bot bot commented Sep 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164142

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 946102e with merge base 3288fbf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label Sep 29, 2025
slayton58 added a commit that referenced this pull request Sep 29, 2025
Summary:

* Add `torch.quantization.scaled_mm` as an abstraction around the C++
  methods
* Wraps `torch._scaled_mm_v2` API by default, but user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>
ghstack-source-id: 6f762a3
Pull-Request: #164142

@slayton58 (Contributor Author):

@pytorchbot rebase

Quantization
============

.. automodule:: torch.quantization.scaled_mm

Contributor:

nit: I'd probably not add stuff to torch.quantization, as it's deprecated and the name is ambiguous. Can we have a more specific name for scaled_mm, or just keep it together with the non-scaled gemms in terms of naming / docs?

Contributor Author:

I'm fine with this going wherever - It's in torch.quantization because there wasn't a clearly better place for it to go - it's not a functional version of a torch.nn op, so torch.nn.functional didn't seem like a good place

Contributor Author:

After some offline discussion, this moved to torch.nn.functional.scaled_mm.

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

pytorchmergebot pushed a commit that referenced this pull request Sep 29, 2025
ghstack-source-id: e727cbc
Pull-Request: #164142

@pytorchmergebot (Collaborator):

Successfully rebased gh/slayton58/17/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/164142)

test/test_fx.py Outdated
"adaptive_avg_pool3d": LEN_ERROR,
"adaptive_max_pool2d_with_indices": LEN_ERROR,
"adaptive_max_pool3d_with_indices": LEN_ERROR,
"scaled_mm": LEN_ERROR,

Contributor:

OOC what does this signify?

Contributor Author:

len(...) is non-traceable by default (errors out during the test), and scaled_mm uses len for some list processing - this prevents the test from running.

Contributor:

Does this mean that this new op won't be compilable?

Contributor Author:

I guess so - I see 2 immediate ways to fix this:

  • Remove the error checking for the deprecated fallback path, just pass the first scale etc. from the passed list, and rely on C++ erroring out for an invalid scaling recipe
  • Remove the deprecated fallback path entirely
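
To make the first option concrete, here is a minimal pure-Python sketch. `ToyProxy`, `pick_scale_checked`, and `pick_scale_unchecked` are hypothetical names invented for illustration; `ToyProxy` is only a toy stand-in for a symbolic-tracing proxy and is not `torch.fx.Proxy`. It assumes, as the thread describes, that `len(...)` cannot be evaluated on a traced value while indexing can be recorded.

```python
class ToyProxy:
    """Hypothetical stand-in for a symbolic-tracing proxy (not torch.fx.Proxy)."""

    def __init__(self, name, log):
        self.name = name
        self.log = log  # shared list that records the "traced" operations

    def __len__(self):
        # The length of a symbolic value is unknown while tracing.
        raise RuntimeError("'len' is not supported in symbolic tracing")

    def __getitem__(self, index):
        # Indexing is recorded instead of being evaluated.
        self.log.append(f"{self.name}[{index}]")
        return ToyProxy(f"{self.name}[{index}]", self.log)


def pick_scale_checked(scales):
    # Fallback path with Python-side validation: calls len(), then indexes.
    if len(scales) != 1:
        raise ValueError("deprecated path expects exactly one scale")
    return scales[0]


def pick_scale_unchecked(scales):
    # First proposed fix: no len() check; invalid input is left for the
    # backend to reject.
    return scales[0]
```

With a plain list both helpers behave the same for valid input; with the toy proxy only the unchecked variant gets past the `len(...)` call. The trade-off is that a malformed scale list is then caught later, by the backend, rather than in Python.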

if len(kwargs) > 0:
    raise RuntimeError("kwargs contains unexpected entries, ", kwargs.keys())

if use_deprecated_api:

Contributor:

OOC why this?

Contributor Author:

The deprecated_api path? It provides a back-compat path to isolate any differences between the implementations. I found it incredibly useful for debugging, though I guess it could be removed if desired.
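
As a rough illustration of that debugging pattern, the sketch below routes one call through either of two backends behind a `use_deprecated_api` flag, so their results can be compared directly. Everything here is hypothetical: `_backend_v1`, `_backend_v2`, and `scaled_mul_sketch` are invented names, plain floats stand in for tensors, and the signature is not the actual `torch.nn.functional.scaled_mm` signature.

```python
def _backend_v2(a, b, scales_a, scales_b):
    # Hypothetical stand-in for the newer backend: applies every scale in each list.
    sa = 1.0
    for s in scales_a:
        sa *= s
    sb = 1.0
    for s in scales_b:
        sb *= s
    return (a * sa) * (b * sb)


def _backend_v1(a, b, scale_a, scale_b):
    # Hypothetical stand-in for the older backend: one scale per operand.
    return (a * scale_a) * (b * scale_b)


def scaled_mul_sketch(a, b, scales_a, scales_b, use_deprecated_api=False, **kwargs):
    # Same guard as in the snippet under review: reject unexpected keyword arguments.
    if len(kwargs) > 0:
        raise RuntimeError("kwargs contains unexpected entries, ", kwargs.keys())

    if use_deprecated_api:
        # Back-compat path: forward only the first scale of each list.
        return _backend_v1(a, b, scales_a[0], scales_b[0])
    return _backend_v2(a, b, scales_a, scales_b)


# Run the same inputs through both paths and compare the results.
new_result = scaled_mul_sketch(2.0, 3.0, [0.5], [2.0])
old_result = scaled_mul_sketch(2.0, 3.0, [0.5], [2.0], use_deprecated_api=True)
assert new_result == old_result
```

Keeping both paths behind one entry point means a mismatch points at a backend difference rather than at the calling code.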

@drisspg (Contributor) left a comment:

Looks good. In a follow-up we should have some more robust composability testing. The immediate things that come to mind are writing a meta_registrations.py entry, adding some compile tests, making sure we (for now) fall back in inductor, and then ultimately rewiring the lowerings.

cc @eellison is the best way to get most of this testing still through common_methods_invocations.py?

@slayton58 (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / win-vs2022-cuda12.6-py3 / build

@slayton58 (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, linux.g6.4xlarge.experimental.nvidia.gpu)

@slayton58 (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

@pytorchmergebot (Collaborator):

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 2, 2, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.gfx942.1)

@slayton58 (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Labels: ciflow/trunk, Merged, release notes: quantization
