Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164141
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 49eb776 with merge base 3288fbf.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "release notes: quantization"
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
Rebased
    bool use_fast_accum,
    Tensor& out) {
  // Restrictions:
  // A, B are FP8, scales are fp32, shape M/N for A/B
Some of these should be TORCH_CHECK_VALUE or raise NotImplementedError
What's the guidance for TORCH_CHECK vs. TORCH_CHECK_VALUE vs. raising NotImplementedError? From a quick look, only TORCH_CHECK is used in this particular file to this point
@slayton58 TORCH_CHECK raises RuntimeError, TORCH_CHECK_VALUE raises ValueError; I don't know if we have a pre-existing macro for raising NotImplementedError. RuntimeErrors are under-specified: it could be a transient error, a user error, etc. ValueError denotes a user error, i.e. the args to the specified function are invalid. NotImplementedError means something could be supported in upstream PyTorch, but we haven't implemented it yet.
Most of this file should probably be changed to ValueError at some point, but that can be a subsequent PR
@slayton58 I like the explicit instead of inferred scaling type. Question: why do the scaling recipe strings only repeat for fp8 types but not for MX types? e.g.,
Seems like we should stick with one pattern for consistency. Is having it repeat twice (presumably once for each operand) to future-proof against future recipes which may have different scaling types for each operand?
- func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CUDA: _scaled_mm_cuda_v2
We should think about how this will look for other backends, e.g. do we ever think that CPU will support some subset of recipes? I don't think that changes anything, just wanted to note it.
CPU can/will certainly support some subset (or superset) of recipes; I'm honestly not sure what is / will be supported though. A CPU backend wouldn't be hard to add (or a shim to dispatch to the existing _scaled_mm_cpu backend).
aten/src/ATen/native/cuda/Blas.cpp
 * strictly-typed enum
 */
template <class EnumType>
std::vector<EnumType> convert_int_to_enum(ArrayRef<long>& v) {
Might be worth pybinding the type, I think I did this w/ SDPBackend to make sure things stay consistent
I am 100% willing to do this, as I don't like the current approach; an example would be amazing. I think I've conflated "no enums through native_functions" with "no enums shared at all between Python/C++ in PyTorch".
Spoke offline but for provenance:
Line 2392 in 11f5f65
bool found_impl = false;
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;

for (const auto& fn_entry : scale_kernel_dispatch) {
Should it be a scan down finding the first match, or should it just be a direct map from recipe pair to impl?
The big problem I have with a map -> impl, is that different implementations can (and imo really should) have different signatures - nvfp4xnvfp4 needs global scales, swizzles passed in addition for instance. I think this might be work-around-able with std::bind to present a unified API, but that's also messy..
Yeah, that makes sense, but to confirm: every acceptance_fn matches against the enum, right? So there is no ambiguity?
@danielvegamyhre isn't a bad idea for consistency.
 * Both inputs must be fp8
 * A, B must only have 1 scale each, A: {Blockwise_1x128 (float)}, B: {Blockwise_128x128 (float)}
 */
bool check_deepseek_recipe(ScalingType expected_recipe_a,
From deepseek? What would you prefer?
Maybe just name it after the 128x recipes, similar to the other ones.
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
using namespace std::placeholders;

std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8> scale_kernel_dispatch = {{
I think I wrote a utility called
Line 15 in 5f15110
which helps w/ seg faults as we expand this list, don't ask me how I know 😂
  { "mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};

Tensor&
_cutlass_scaled_gemm(
Nit: the name is weird since it goes to a lot more than just cutlass
#ifdef USE_ROCM
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
I would love the AMD path to be a little more modular, e.g. one ifdef that says USE_ROCM -> go to the ROCm path, otherwise go to the CUDA path. It makes it a lot easier to grok.
auto scaling_choice_b = ScalingType::RowWise;
//
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
Is rowwise also not supported on sm100 like the 128x recipes?
9.0+ = >= 9.0, so yes, it should be on Blackwell too (I think my CC nomenclature here is different to yours :D )
Ohh, I just meant: didn't the 128x recipes ONLY work on sm90 and not Blackwell? Not sure if cuBLAS did the same thing for the rowwise recipes.
Starting merge as part of PR stack under #164142
Summary:
* Add new scaled-MM API to future-proof / clean up existing code
* Scaling is explicitly described rather than inferred
* Swizzling of scales must now be defined (vs. inferred)
* Adds API support for multi-level scaling
* Refactor dispatch logic to make it easier to add new implementations

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>
[ghstack-poisoned]
Stack from ghstack (oldest at bottom):