Commit 4f940a2

Pallas: generalize permutation detection for N-D tensor transposes
Generalize the permutation detection logic in the Pallas backend to handle arbitrary N-D tensor transposes, not just 2D swaps. Add collapsed-dimension detection for cases where the iteration dimensions don't match the tensor dimensions directly. Skip permutation detection on GPU, where inputs are flattened to 1D.

Also add the cat_unbacked tests to expected failures (a pre-existing dlpack issue on TPU, unrelated to this change).

Authored with assistance from Claude
1 parent e07e28d · commit 4f940a2
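As an illustration of the idea (not code from this patch; the helper name, signature, and stride bookkeeping below are hypothetical), N-D transpose detection can be framed as matching each iteration dimension's (size, stride) pair against the dimensions of a contiguous input:

    # Illustrative sketch only -- not the actual Inductor/Pallas code.
    # detect_permutation, iter_sizes, and load_strides are hypothetical names.
    def detect_permutation(iter_sizes, load_strides, input_sizes):
        """Map each iteration dim to the input dim it reads; None if not a transpose."""
        if len(iter_sizes) != len(input_sizes):
            return None  # rank mismatch: needs collapsed-dimension handling

        # Strides of a contiguous input, innermost stride == 1.
        input_strides = [1] * len(input_sizes)
        for d in range(len(input_sizes) - 2, -1, -1):
            input_strides[d] = input_strides[d + 1] * input_sizes[d + 1]

        perm = []
        for size, stride in zip(iter_sizes, load_strides):
            # Match this iteration dim to an unused input dim with the
            # same (size, stride) pair.
            for d, (s, st) in enumerate(zip(input_sizes, input_strides)):
                if s == size and st == stride and d not in perm:
                    perm.append(d)
                    break
            else:
                return None  # some iteration dim is not a pure dim read
        return perm

For example, x.permute(1, 0, 2) on a contiguous (2, 1152, 2048) input is iterated with sizes (1152, 2, 2048) and load strides (2048, 2359296, 1), so the sketch returns [1, 0, 2]. When the ranks differ, the collapsed-dimension case from the commit message applies; see the sketch after the diff below.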

8 files changed

Lines changed: 1201 additions & 313 deletions

File tree

test/inductor/pallas_expected_failures/CpuTests.test_lerp_cpu

Lines changed: 0 additions & 5 deletions
This file was deleted.

test/inductor/pallas_expected_failures/CpuTests.test_permute1_cpu

Lines changed: 0 additions & 5 deletions
This file was deleted.

test/inductor/pallas_expected_failures/CpuTests.test_prod_cpu

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/inductor/pallas_expected_failures/CpuTests.test_scaled_dot_product_attention_cpu

Lines changed: 0 additions & 5 deletions
This file was deleted.

test/inductor/pallas_skip_tests/CpuTests.test_large_block_sizes_cpu

Whitespace-only changes.

test/inductor/test_pallas.py

Lines changed: 346 additions & 0 deletions
@@ -1721,6 +1721,352 @@ def transformer_block(x, w_q, w_k, w_v, w_proj, w_fc, w_out, mask):
         expected = transformer_block(x, w_q, w_k, w_v, w_proj, w_fc, w_out, mask)
         self.assertEqual(result, expected)
 
+    def _run_transformer_layer(
+        self, seq_len, hidden_dim, num_heads, head_dim, ffn_dim, atol=1e-5, rtol=1.3e-6
+    ):
+        """Run a Llama-style transformer layer forward pass and verify correctness.
+
+        Architecture: RMSNorm -> Multi-Head Attention -> Residual ->
+        RMSNorm -> SwiGLU FFN -> Residual
+        """
+        torch._dynamo.reset()
+
+        def transformer_layer(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        ):
+            T, C = x.shape
+
+            # Pre-attention RMSNorm
+            variance = x.pow(2).mean(-1, keepdim=True)
+            h = x * torch.rsqrt(variance + 1e-6) * rms_w1
+
+            # Multi-head self-attention
+            q = (h @ w_q).view(T, num_heads, head_dim).permute(1, 0, 2)  # (H, T, D)
+            k = (h @ w_k).view(T, num_heads, head_dim).permute(1, 0, 2)
+            v = (h @ w_v).view(T, num_heads, head_dim).permute(1, 0, 2)
+
+            scale = 1.0 / (head_dim**0.5)
+            att = (q @ k.transpose(-2, -1)) * scale  # (H, T, T)
+            att = att + mask  # causal mask broadcasts (T, T) -> (H, T, T)
+            att = torch.softmax(att, dim=-1)
+            attn_out = (att @ v).permute(1, 0, 2).contiguous().view(T, C)  # (T, C)
+
+            x = x + (attn_out @ w_o)
+
+            # Pre-FFN RMSNorm
+            variance = x.pow(2).mean(-1, keepdim=True)
+            h = x * torch.rsqrt(variance + 1e-6) * rms_w2
+
+            # SwiGLU FFN
+            gate = torch.nn.functional.silu(h @ w_gate)
+            up = h @ w_up
+            x = x + ((gate * up) @ w_down)
+
+            return x
+
+        compiled = self._compile(transformer_layer)
+
+        # Initialize weights with small values for numerical stability
+        s = 0.02
+        w_q = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_k = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_v = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_o = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_gate = torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s
+        w_up = torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s
+        w_down = torch.randn(ffn_dim, hidden_dim, device=self.DEVICE) * s
+        rms_w1 = torch.ones(hidden_dim, device=self.DEVICE)
+        rms_w2 = torch.ones(hidden_dim, device=self.DEVICE)
+
+        # Causal mask (T, T) - broadcasts over heads
+        mask = torch.triu(
+            torch.full((seq_len, seq_len), float("-inf"), device=self.DEVICE),
+            diagonal=1,
+        )
+
+        x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) * 0.02
+
+        result = compiled(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        )
+        expected = transformer_layer(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        )
+        self.assertEqual(result, expected, atol=atol, rtol=rtol)
+
+    @skip_if_cuda
+    def test_transformer_layer_tiny(self):
+        """Test full Llama-style transformer layer at tiny dimensions."""
+        self._run_transformer_layer(
+            seq_len=32,
+            hidden_dim=64,
+            num_heads=2,
+            head_dim=32,
+            ffn_dim=256,
+        )
+
+    @skip_if_cuda
+    def test_transformer_layer_medium(self):
+        """Test full Llama-style transformer layer at Llama-7B dimensions."""
+        self._run_transformer_layer(
+            seq_len=128,
+            hidden_dim=4096,
+            num_heads=32,
+            head_dim=128,
+            ffn_dim=11008,
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+    @skip_if_cuda
+    def test_transformer_layer_large(self):
+        """Test full Llama-style transformer layer at Llama-405B dimensions."""
+        self._run_transformer_layer(
+            seq_len=32,
+            hidden_dim=16384,
+            num_heads=128,
+            head_dim=128,
+            ffn_dim=53248,
+            atol=2e-3,
+            rtol=1e-3,
+        )
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d(self):
+        """Test that permute + contiguous on a 3D tensor produces correct results."""
+
+        def fn(x):
+            return x.permute(1, 0, 2).contiguous()
+
+        compiled = self._compile(fn)
+        x = torch.randn(2, 32, 32, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_transpose_contiguous_2d(self):
+        """Test that transpose + contiguous on a 2D tensor compiles and runs."""
+
+        def fn(x):
+            return x.transpose(0, 1).contiguous()
+
+        compiled = self._compile(fn)
+        x = torch.randn(2, 32, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_2d_asymmetric(self):
+        """Test transpose+contiguous on asymmetric 2D shapes."""
+
+        def fn(x):
+            return x.transpose(0, 1).contiguous()
+
+        compiled = self._compile(fn)
+        for shape in [(10000, 2), (2, 10000), (8, 128), (128, 8), (3, 256)]:
+            with self.subTest(shape=shape):
+                x = torch.randn(*shape, device=self.DEVICE)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_all_perms(self):
+        """Test non-identity 3D permutations with distinct dim sizes."""
+        all_perms = [
+            (1, 0, 2),
+            (0, 2, 1),
+            (2, 1, 0),  # full-rank detection
+        ]
+        x = torch.randn(2, 1152, 2048, device=self.DEVICE)
+
+        for perm in all_perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    @skip_if_tpu
+    def test_permute_contiguous_3d_collapsed(self):
+        """Test 3D permutations that require collapsed-dim detection."""
+        all_perms = [
+            (2, 0, 1),
+            (1, 2, 0),
+        ]
+        x = torch.randn(2, 1152, 2048, device=self.DEVICE)
+
+        for perm in all_perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_large_notile(self):
+        """Test 3D perms with large shape but no tiling (dim=1024 exact fit)."""
+        perms = [(1, 0, 2), (0, 2, 1), (2, 1, 0)]
+
+        x = torch.randn(2, 1152, 1024, device=self.DEVICE)
+
+        for perm in perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_medium(self):
+        """Test 3D perms with medium shape that triggers tiling on last dim."""
+        perms = [(1, 0, 2), (0, 2, 1), (2, 1, 0)]
+
+        x = torch.randn(2, 8, 2048, device=self.DEVICE)
+
+        for perm in perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_small(self):
+        """Test 3D perms with small shapes that produce grid=(1,)."""
+        all_perms = [
+            (1, 0, 2),
+            (0, 2, 1),
+            (2, 1, 0),
+            (2, 0, 1),
+            (1, 2, 0),
+        ]
+        # (2,0,1) on (2,8,16) triggers a Mosaic "unsupported shape cast"
+        # bug on TPU due to internal tile padding (128x16 -> 16x128).
+        tpu_skip = {(2, 0, 1)} if self.DEVICE == "tpu" else set()
+        x = torch.randn(2, 8, 16, device=self.DEVICE)
+
+        for perm in all_perms:
+            if perm in tpu_skip:
+                continue
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    @skip_if_tpu
+    def test_permute_contiguous_4d(self):
+        """Test all 23 non-identity 4D permutations with multi-tile grids."""
+        all_perms = [
+            (0, 1, 3, 2),
+            (0, 2, 1, 3),
+            (0, 2, 3, 1),
+            (0, 3, 1, 2),
+            (0, 3, 2, 1),
+            (1, 0, 2, 3),
+            (1, 0, 3, 2),
+            (1, 2, 0, 3),
+            (1, 2, 3, 0),
+            (1, 3, 0, 2),
+            (1, 3, 2, 0),
+            (2, 0, 1, 3),
+            (2, 0, 3, 1),
+            (2, 1, 0, 3),
+            (2, 1, 3, 0),
+            (2, 3, 0, 1),
+            (2, 3, 1, 0),
+            (3, 0, 1, 2),
+            (3, 0, 2, 1),
+            (3, 1, 0, 2),
+            (3, 1, 2, 0),
+            (3, 2, 0, 1),
+            (3, 2, 1, 0),
+        ]
+        needs_large_first = {
+            (0, 3, 2, 1),
+            (1, 3, 0, 2),
+            (1, 3, 2, 0),
+            (2, 3, 1, 0),
+            (3, 0, 2, 1),
+            (3, 1, 0, 2),
+            (3, 1, 2, 0),
+            (3, 2, 0, 1),
+            (3, 2, 1, 0),
+        }
+        shapes = {
+            "large_last": (2, 4, 128, 2048),
+            "large_first": (1152, 1152, 2, 4),
+        }
+
+        for perm in all_perms:
+            key = "large_first" if perm in needs_large_first else "large_last"
+            shape = shapes[key]
+            with self.subTest(perm=perm, shape=shape):
+                x = torch.randn(*shape, device=self.DEVICE)
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
     def test_warpgroup_size_2d_aligned_32x8(self):
         """Test 2D tensor with 32x8 = 256 elements (2 warpgroups)."""
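The collapsed-dimension tests above (perms (2, 0, 1) and (1, 2, 0) in test_permute_contiguous_3d_collapsed) cover the case the commit message calls out: codegen can fuse adjacent dimensions, so the iteration space has lower rank than the tensor. A hypothetical helper (illustrative only, not the patch's implementation) shows the extra matching step such detection needs:

    # Illustrative sketch only -- not the actual Inductor/Pallas code.
    def match_collapsed(iter_size, iter_stride, input_sizes, input_strides):
        """Find the run of adjacent input dims that one collapsed iteration dim covers.

        Assumes a contiguous input. Returns the list of input dims whose sizes
        multiply to iter_size and whose innermost stride equals iter_stride,
        or None if no such run exists.
        """
        n = len(input_sizes)
        for start in range(n):
            prod = 1
            for end in range(start, n):
                prod *= input_sizes[end]
                if prod == iter_size and input_strides[end] == iter_stride:
                    return list(range(start, end + 1))
                if prod > iter_size:
                    break
        return None

For x.permute(2, 0, 1) on a contiguous (2, 1152, 2048) input, the iteration space can collapse to (2048, 2304): the 2304-wide dimension carries load stride 2048, and match_collapsed(2304, 2048, [2, 1152, 2048], [2359296, 2048, 1]) recovers input dims [0, 1].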