Add scaled_mm python API, test #164142
slayton58 wants to merge 30 commits into gh/slayton58/17/base from
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164142
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 946102e with merge base 3288fbf. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Summary:
* Add `torch.quantization.scaled_mm` as an abstraction around the C++ methods
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface
* Scaled MM tests now run on the new API

Test Plan: `pytest test/test_scaled_matmul_cuda.py`

Reviewers:
Subscribers:
Tasks:
Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>
ghstack-source-id: 6f762a3
Pull-Request: #164142
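Conceptually, a scaled matmul multiplies low-precision operands with their scales applied, accumulating in higher precision. A minimal pure-Python sketch of the per-tensor-scale semantics (the function name and argument layout here are illustrative only, not the actual `torch.nn.functional.scaled_mm` signature):

```python
def scaled_mm_reference(a, b, scale_a, scale_b):
    """Reference semantics: C = (a * scale_a) @ (b * scale_b).

    a: M x K matrix (list of lists), b: K x N matrix.
    scale_a / scale_b: per-tensor scalar scales (illustrative only;
    the real API also supports richer scaling recipes).
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                # Accumulate in full precision, applying both scales.
                acc += (a[i][p] * scale_a) * (b[p][j] * scale_b)
            out[i][j] = acc
    return out

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[1.0, 0.0], [0.0, 1.0]]
print(scaled_mm_reference(A, B, 0.5, 2.0))  # scales cancel: [[1.0, 2.0], [3.0, 4.0]]
```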
@pytorchbot rebase
docs/source/quantization.rst (outdated)

    Quantization
    ============

    .. automodule:: torch.quantization.scaled_mm
nit: I'd probably not add stuff to torch.quantization, as it's deprecated and the name is ambiguous. Can we have a more specific name for scaled_mm, or just keep it together with the non-scaled gemms in terms of naming/docs?
I'm fine with this going wherever. It's in torch.quantization because there wasn't a clearly better place for it to go: it's not a functional version of a torch.nn op, so torch.nn.functional didn't seem like a good place either.
After some offline discussion, moved to torch.nn.functional.scaled_mm.
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.
Summary:
* Add `torch.quantization.scaled_mm` as an abstraction around the C++ methods
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface
* Scaled MM tests now run on the new API

Test Plan: `pytest test/test_scaled_matmul_cuda.py`

Signed-off-by: Simon Layton <simonlayton@meta.com>
ghstack-source-id: e727cbc
Pull-Request: #164142
|
Successfully rebased
test/test_fx.py (outdated)

    "adaptive_avg_pool3d": LEN_ERROR,
    "adaptive_max_pool2d_with_indices": LEN_ERROR,
    "adaptive_max_pool3d_with_indices": LEN_ERROR,
    "scaled_mm": LEN_ERROR,
OOC what does this signify?
len(...) is non-traceable by default (it errors out during the test), and scaled_mm uses len() for some list processing; this entry prevents the test from running it.
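A toy illustration of why `len()` breaks symbolic tracing: a proxy object can defer most operators by returning another proxy, but Python requires `__len__` to return a concrete `int`, so a symbolic value cannot flow through it. This is a simplified stand-in, not the actual `torch.fx.Proxy` implementation:

```python
class Proxy:
    """Toy stand-in for an FX proxy: records operations symbolically."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        # Deferrable: returns a new symbolic Proxy instead of computing.
        return Proxy(f"({self.expr} + {other})")

    def __len__(self):
        # len() must return a real int, so a symbolic value cannot be deferred.
        raise TypeError("len() of a symbolic Proxy is not traceable")


x = Proxy("x")
y = x + 1          # fine: builds a symbolic expression
print(y.expr)      # (x + 1)
try:
    len(x)         # breaks: len() cannot stay symbolic
except TypeError as e:
    print(e)
```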
Does this mean that this new op won't be compilable?
I guess so. I see two immediate ways to fix this:
- Remove the error-checking for the deprecated fallback path and just pass the first scale etc. from the passed list, relying on the C++ side to error out on an invalid scaling recipe
- Remove the deprecated fallback path entirely
torch/nn/functional.py (outdated)

    if len(kwargs) > 0:
        raise RuntimeError("kwargs contains unexpected entries, ", kwargs.keys())

    if use_deprecated_api:
The deprecated_api path? It provides a back-compat route that isolates any differences between the two implementations. I found it incredibly useful for debugging; I guess it could be removed if desired.
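The back-compat pattern under discussion can be sketched as a single entry point with an escape hatch to the old path. All names below are illustrative stand-ins (with plain-Python matmuls as the "implementations"), not the real torch internals:

```python
def _scaled_mm_v2_impl(a, b):
    # Stand-in for the new (_scaled_mm_v2) code path: a plain matmul.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def _scaled_mm_legacy_impl(a, b):
    # Stand-in for the deprecated (_scaled_mm) code path. Keeping it callable
    # lets you diff the two implementations when debugging discrepancies.
    return _scaled_mm_v2_impl(a, b)


def scaled_mm(a, b, use_deprecated_api=False):
    """Dispatch to the legacy path only when explicitly requested."""
    if use_deprecated_api:
        return _scaled_mm_legacy_impl(a, b)
    return _scaled_mm_v2_impl(a, b)


A = [[1.0, 2.0], [3.0, 4.0]]
I2 = [[1.0, 0.0], [0.0, 1.0]]
# Both paths should agree; a mismatch here would localize a bug to one path.
assert scaled_mm(A, I2) == scaled_mm(A, I2, use_deprecated_api=True)
```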
drisspg left a comment
Looks good. In a follow-up we should add some more robust composability testing; the immediate things that come to mind are writing a meta_registrations.py entry, adding some compile tests, making sure we (for now) fall back in Inductor, and then ultimately rewiring the lowerings.
cc @eellison: is the best way to get most of this testing still through common_methods_invocations.py?
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / win-vs2022-cuda12.6-py3 / build.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 job has failed: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, linux.g6.4xlarge.experimental.nvidia.gpu).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 2 jobs have failed, the first few of them: trunk / linux-jammy-rocm-py3.10 / test (default, 2, 2, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 2, linux.rocm.gpu.gfx942.1).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Stack from ghstack (oldest at bottom):
Summary:
* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++ methods
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface

Test Plan: `pytest test/test_scaled_matmul_cuda.py`

Reviewers:
Subscribers:
Tasks:
Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>