
Commit 0b62465

Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"
This reverts commit 830a335. Reverted #155466 on behalf of https://github.com/atalman because it breaks internal builds ([comment](#155466 (comment)))
1 parent fec8af8 commit 0b62465

5 files changed

Lines changed: 53 additions & 136 deletions
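For context, the check being reverted concerns the per-group sizes implied by consecutive entries of the int32 offs tensor: along the dynamic dimension, each group has to be a non-negative multiple of 16 bytes so TMA transfers stay aligned. A minimal host-side sketch of that condition (the helper name and the use of torch.diff are illustrative, not part of this patch):

import torch

def groups_are_tma_aligned(offs: torch.Tensor, dtype: torch.dtype) -> bool:
    # Per-group sizes are the differences between consecutive offsets.
    deltas = torch.diff(offs, prepend=torch.zeros(1, dtype=offs.dtype, device=offs.device))
    # 16-byte alignment, expressed in elements of the given dtype.
    align = 16 // torch.empty((), dtype=dtype).element_size()
    return bool(((deltas >= 0) & (deltas % align == 0)).all())

offs = torch.arange(16, 4 * 16 + 1, 16, dtype=torch.int32)  # 4 groups of 16, as in the tests below
assert groups_are_tma_aligned(offs, torch.bfloat16)         # 16 is a multiple of 16 B / 2 B = 8 elements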


aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 16 additions & 32 deletions
@@ -36,7 +36,6 @@
 #include <ATen/ops/copy_native.h>
 #include <ATen/ops/dot_native.h>
 #include <ATen/ops/empty.h>
-#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/gelu.h>
 #include <ATen/ops/max.h>
 #include <ATen/ops/mm_native.h>
@@ -1482,49 +1481,29 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 }

 namespace {
-at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a,
+c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
   const Tensor& mat_b,
-  const std::optional<at::Tensor>& offs,
-  std::optional<c10::ScalarType> out_dtype
+  const std::optional<at::Tensor>& offs
 ) {
-  c10::SmallVector<int64_t, 3> out_size;
   const bool a_is_2d = mat_a.dim() == 2;
   const bool b_is_2d = mat_b.dim() == 2;
   if (a_is_2d) {
     if (b_is_2d) {
-      out_size = {offs->size(0), mat_a.size(0), mat_b.size(1)};
+      return {offs->size(0), mat_a.size(0), mat_b.size(1)};
     } else {
       TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
-      out_size = {mat_a.size(0), mat_b.size(-1)};
+      return {mat_a.size(0), mat_b.size(-1)};
     }
   } else {
     if (b_is_2d) {
       // this case is not actually encountered for MoE gemms
       TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
-      out_size = {mat_a.size(1), mat_b.size(1)};
+      return {mat_a.size(1), mat_b.size(1)};
     } else { // regular bmm
       TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
-      out_size = {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
+      return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
     }
   }
-
-  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
-  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
-
-  // For TMA transfers, strides of output tensor have to be either
-  // 1, or aligned to 16 bytes.
-  const auto last_dim = out_size.size() - 1;
-  const auto alignment = 16 / c10::elementSize(out_dtype_);
-  const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
-  std::vector<int64_t> out_stride;
-  if (a_is_2d != b_is_2d) {
-    out_stride = {size_padded, 1};
-  } else {
-    out_stride = {out_size[1] * size_padded, size_padded, 1};
-  }
-  auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
-
-  return out;
 }

 bool check_valid_strides_and_return_transposed(const Tensor& mat) {
@@ -1540,7 +1519,7 @@ namespace {
     TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes");
     return false;
   } else {
-    TORCH_CHECK(false, "Invalid strides/sizes, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
+    TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes");
   }
 }

@@ -1648,7 +1627,11 @@ bool use_fast_accum) {
   check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
   check_scale(mat_b, scale_b, 1, 1, scale_multiplier);

-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+  const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
+  Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
+

   at::cuda::detail::f8f8bf16_grouped_mm(
       mat_a,
@@ -1684,7 +1667,6 @@ std::optional<c10::ScalarType> out_dtype) {
   TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
   const bool a_is_2d = mat_a.dim() == 2;
   const bool b_is_2d = mat_b.dim() == 2;
-
   // check that the strides are valid, the fn will throw an error if not
   check_valid_strides_and_return_transposed(mat_a);
   check_valid_strides_and_return_transposed(mat_b);
@@ -1694,10 +1676,12 @@ std::optional<c10::ScalarType> out_dtype) {
     TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
     TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
   }
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm");
   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");

-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
-
+  const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
+  Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
   at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
   return out;
 #else
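The removed create_grouped_gemm_output_tensor helper padded the innermost output stride up to a 16-byte multiple and allocated with at::empty_strided; the revert goes back to a plain contiguous at::empty. A Python mirror of the removed stride arithmetic, for illustration only (the function name is made up, a bf16 element size is assumed):

def padded_output_strides(out_size, itemsize=2):           # bf16 -> 2 bytes per element
    alignment = 16 // itemsize                              # elements per 16 bytes
    size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
    if len(out_size) == 2:                                   # 2d/3d and 3d/2d cases
        return [size_padded, 1]
    return [out_size[1] * size_padded, size_padded, 1]       # 2d/2d and regular bmm cases

# A (4, 16, 30) bf16 output would get a row stride of 32 elements, i.e. 2 padding elements per row.
assert padded_output_strides([4, 16, 30]) == [16 * 32, 32, 1]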

aten/src/ATen/native/cuda/GroupMMCommon.cuh

Lines changed: 5 additions & 36 deletions
@@ -47,42 +47,10 @@ __global__ void prepare_grouped_gemm_data(
   if (offs != nullptr) {
     int32_t start = tid == 0 ? 0 : offs[tid - 1];
     delta = offs[tid] - start;
-    if (K < 0) {
-      // CUTLASS cannot handle delta=0 here.
-      CUDA_KERNEL_ASSERT(delta >0 && "expected ofsets to be greater than 0\n");
-    } else {
-      CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
-    }
-
-    // TMA transfers require global memory tensor addresses to be
-    // aligned to 16 bytes.
-    if (tid < blockDim.x - 1) {
-      // Check this requirement for input tensors, in case group
-      // addresses are increased along the dynamic dimension.
-      if ((K < 0 && a_row_major) || // 2D/2D: check along K dimension
-          (M < 0 && !a_row_major)) { // 3D/2D: check along N dimension
-        int align = 128 / cutlass::sizeof_bits<DtypeA>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-      if ((K < 0 && !b_row_major) || // 2D/2D: check along K dimension
-          (N < 0 && b_row_major)) { // 3D/2D: check along N dimension
-        int align = 128 / cutlass::sizeof_bits<DtypeB>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-
-      // Check the same requirement for output tensor (that is always
-      // contiguous, and in row-major layout).
-      if (N < 0) {
-        int align = 128 / cutlass::sizeof_bits<DtypeOutput>::value;
-        CUDA_KERNEL_ASSERT(
-            delta % align == 0 &&
-            "expected output tensor dynamic dimension byte size to be non-negative multiple of 16\n");
-      }
-    }
+    int align = 16 / sizeof(DtypeA);
+    CUDA_KERNEL_ASSERT(
+        delta >=0 && delta % align == 0 &&
+        "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
   }
   int64_t lda, ldb, ldoutput;
   if (M < 0) {
@@ -113,6 +81,7 @@ __global__ void prepare_grouped_gemm_data(
   } else if (K < 0) {
     // A, B is 2d, output is 3d
     K = delta;
+    CUDA_KERNEL_ASSERT(delta > 0 && "can't handle K=0");
     lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
     ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1];
     ldoutput = tensor_StrideOutput[1];
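The removed per-tensor checks (128 / cutlass::sizeof_bits<Dtype>) and the restored single check (16 / sizeof(DtypeA)) compute the same per-dtype quantity, the number of elements that fit in one 16-byte (128-bit) TMA transfer; the restored check just applies it to DtypeA only. A quick sanity check of that arithmetic for the dtypes these kernels handle:

for name, bits in [("float8_e4m3fn", 8), ("bfloat16", 16)]:
    align_from_bits = 128 // bits          # removed form: 128 / sizeof_bits<Dtype>::value
    align_from_bytes = 16 // (bits // 8)   # restored form: 16 / sizeof(Dtype)
    assert align_from_bits == align_from_bytes
    print(f"{name}: {align_from_bits} elements per 16 bytes")  # fp8 -> 16, bf16 -> 8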

test/test_matmul_cuda.py

Lines changed: 10 additions & 47 deletions
@@ -315,7 +315,7 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist)
     def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, use_torch_compile):
         device = "cuda"
         dtype = torch.bfloat16
-        m, n, k, n_groups = 16, 32, 64, 4
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         if a_row_major:
             a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
         else:
@@ -382,9 +382,6 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, use_torch_c
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
         for check_zero_size in (False, True):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             a.grad = None
             b.grad = None
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
@@ -487,9 +484,6 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, use_torch_c
         b_contig = b if b_row_major else b.transpose(-2, -1)
         self.assertTrue(b_contig.is_contiguous() is not strided)
         for check_zero_size in (False, True):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
@@ -1651,27 +1645,17 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
             out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                        out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
-            self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
-
-    # Testing only _scaled_grouped_mm() with multiple shapes, as
-    # _scaled_mm() already has more combinations of parameters than
-    # _scaled_grouped_mm(), for supporing more than one inputs layout
-    # combinations.
+            self.assertEqual(out, out_ref, atol=8e-2, rtol=8e-4)

     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
@@ -1701,26 +1685,18 @@ def test_scaled_grouped_gemm_2d_2d(self, n_groups, m, n, k, fast_accum, strided,
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
         for check_zero_size in (True, False):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
@@ -1751,18 +1727,13 @@ def test_scaled_grouped_gemm_2d_3d(self, n_groups, m, n, k, fast_accum, strided,
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1786,28 +1757,20 @@ def test_scaled_grouped_gemm_3d_3d(self, n_groups, m, n, k, fast_accum, strided,
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
     @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
-    @parametrize(
-        "n_groups, m, n, k",
-        [(2, 1, 16, 16),
-         (4, 16, 16, 16)],
-        name_fn=lambda n_groups, m, n, k: f"{n_groups}_{m}_{n}_{k}",
-    )
     @parametrize("fast_accum", [False, True])
     @parametrize("strided", [False, True])
     @parametrize("use_torch_compile", [False, True])
-    def test_scaled_grouped_gemm_3d_2d(self, n_groups, m, n, k, fast_accum, strided, use_torch_compile):
+    def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
         scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
         scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
         for check_zero_size in (True, False):
-            if check_zero_size and n_groups <= 1:
-                continue
-
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
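The check_zero_size branches above still exercise empty groups after the shape parametrization is dropped. For reference, this is the offsets pattern those loops build, shown standalone on CPU with the m=16, n_groups=4 values used here:

import torch

m, n_groups = 16, 4
offs = torch.arange(m, n_groups * m + 1, m, dtype=torch.int32)  # [16, 32, 48, 64]
offs[0] = offs[1]                                               # makes the second group empty
deltas = torch.diff(offs, prepend=torch.zeros(1, dtype=torch.int32))
print(offs.tolist(), deltas.tolist())                           # [32, 32, 48, 64] [32, 0, 16, 16]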

torch/_inductor/kernel/mm_scaled_grouped.py

Lines changed: 3 additions & 2 deletions
@@ -120,6 +120,7 @@ def early_config_prune(g, m, configs, named_args):
     return pruned_configs


+# Copied from fbgemm grouped_gemm.py
 triton_grouped_mm_source = r"""
 {%- if SCALED %}
 {%- if A_IS_2D or B_IS_2D %}
@@ -670,7 +671,7 @@ def _tuned_grouped_mm_common(
     )


-@register_lowering(aten._grouped_mm.default, type_promotion_kind=None)
+@register_lowering(aten._grouped_mm, type_promotion_kind=None)
 def tuned_grouped_mm(
     mat_a: TensorBox,
     mat_b: TensorBox,
@@ -682,7 +683,7 @@ def tuned_grouped_mm(
     """Auto-tuning for _grouped_mm() operator."""

     return _tuned_grouped_mm_common(
-        "aten._grouped_mm.default",
+        "aten._grouped_mm",
         "grouped_mm",
         aten__grouped_mm,
         triton_grouped_mm_template,
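The lowering is re-registered against the overload packet rather than its .default overload. For the distinction, assuming a build that exposes the private aten._grouped_mm op (as this branch does), a small illustrative snippet not taken from the patch:

import torch

packet = torch.ops.aten._grouped_mm            # OpOverloadPacket: resolves an overload at call time
overload = torch.ops.aten._grouped_mm.default  # OpOverload: one specific schema
print(type(packet).__name__, type(overload).__name__)  # OpOverloadPacket OpOverload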
