@@ -1721,6 +1721,352 @@ def transformer_block(x, w_q, w_k, w_v, w_proj, w_fc, w_out, mask):
         expected = transformer_block(x, w_q, w_k, w_v, w_proj, w_fc, w_out, mask)
         self.assertEqual(result, expected)
 
+    def _run_transformer_layer(
+        self, seq_len, hidden_dim, num_heads, head_dim, ffn_dim, atol=1e-5, rtol=1.3e-6
+    ):
+        """Run a Llama-style transformer layer forward pass and verify correctness.
+
+        Architecture: RMSNorm -> Multi-Head Attention -> Residual ->
+        RMSNorm -> SwiGLU FFN -> Residual
+        """
+        torch._dynamo.reset()
+
+        def transformer_layer(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        ):
+            T, C = x.shape
+
+            # Pre-attention RMSNorm
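+            # (equivalent to torch.nn.functional.rms_norm(x, (C,), rms_w1, eps=1e-6) in recent PyTorch)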
+            variance = x.pow(2).mean(-1, keepdim=True)
+            h = x * torch.rsqrt(variance + 1e-6) * rms_w1
+
+            # Multi-head self-attention
+            q = (h @ w_q).view(T, num_heads, head_dim).permute(1, 0, 2)  # (H, T, D)
+            k = (h @ w_k).view(T, num_heads, head_dim).permute(1, 0, 2)
+            v = (h @ w_v).view(T, num_heads, head_dim).permute(1, 0, 2)
+
+            scale = 1.0 / (head_dim**0.5)
+            att = (q @ k.transpose(-2, -1)) * scale  # (H, T, T)
+            att = att + mask  # causal mask broadcasts (T, T) -> (H, T, T)
+            att = torch.softmax(att, dim=-1)
+            attn_out = (att @ v).permute(1, 0, 2).contiguous().view(T, C)  # (T, C)
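+            # (the attention block above is equivalent to
+            #  torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask))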
+
+            x = x + (attn_out @ w_o)
+
+            # Pre-FFN RMSNorm
+            variance = x.pow(2).mean(-1, keepdim=True)
+            h = x * torch.rsqrt(variance + 1e-6) * rms_w2
+
+            # SwiGLU FFN
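+            # (Llama-style MLP: silu(h @ w_gate) gates (h @ w_up); w_down projects back to hidden_dim)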
+            gate = torch.nn.functional.silu(h @ w_gate)
+            up = h @ w_up
+            x = x + ((gate * up) @ w_down)
+
+            return x
+
+        compiled = self._compile(transformer_layer)
+
+        # Initialize weights with small values for numerical stability
+        s = 0.02
+        w_q = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_k = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_v = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_o = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s
+        w_gate = torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s
+        w_up = torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s
+        w_down = torch.randn(ffn_dim, hidden_dim, device=self.DEVICE) * s
+        rms_w1 = torch.ones(hidden_dim, device=self.DEVICE)
+        rms_w2 = torch.ones(hidden_dim, device=self.DEVICE)
+
+        # Causal mask (T, T) - broadcasts over heads
+        mask = torch.triu(
+            torch.full((seq_len, seq_len), float("-inf"), device=self.DEVICE),
+            diagonal=1,
+        )
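+        # Strictly upper-triangular entries are -inf, so softmax assigns zero weight to future positions.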
+
+        x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) * 0.02
+
+        result = compiled(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        )
+        expected = transformer_layer(
+            x,
+            rms_w1,
+            rms_w2,
+            w_q,
+            w_k,
+            w_v,
+            w_o,
+            w_gate,
+            w_up,
+            w_down,
+            mask,
+        )
+        self.assertEqual(result, expected, atol=atol, rtol=rtol)
+
+    @skip_if_cuda
+    def test_transformer_layer_tiny(self):
+        """Test full Llama-style transformer layer at tiny dimensions."""
+        self._run_transformer_layer(
+            seq_len=32,
+            hidden_dim=64,
+            num_heads=2,
+            head_dim=32,
+            ffn_dim=256,
+        )
+
+    @skip_if_cuda
+    def test_transformer_layer_medium(self):
+        """Test full Llama-style transformer layer at Llama-7B dimensions."""
+        self._run_transformer_layer(
+            seq_len=128,
+            hidden_dim=4096,
+            num_heads=32,
+            head_dim=128,
+            ffn_dim=11008,
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+    @skip_if_cuda
+    def test_transformer_layer_large(self):
+        """Test full Llama-style transformer layer at Llama-405B dimensions."""
+        self._run_transformer_layer(
+            seq_len=32,
+            hidden_dim=16384,
+            num_heads=128,
+            head_dim=128,
+            ffn_dim=53248,
+            atol=2e-3,
+            rtol=1e-3,
+        )
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d(self):
+        """Test that permute + contiguous on a 3D tensor produces correct results."""
+
+        def fn(x):
+            return x.permute(1, 0, 2).contiguous()
+
+        compiled = self._compile(fn)
+        x = torch.randn(2, 32, 32, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_transpose_contiguous_2d(self):
+        """Test that transpose + contiguous on a 2D tensor compiles and runs."""
+
+        def fn(x):
+            return x.transpose(0, 1).contiguous()
+
+        compiled = self._compile(fn)
+        x = torch.randn(2, 32, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_2d_asymmetric(self):
+        """Test transpose + contiguous on asymmetric 2D shapes."""
+
+        def fn(x):
+            return x.transpose(0, 1).contiguous()
+
+        compiled = self._compile(fn)
+        for shape in [(10000, 2), (2, 10000), (8, 128), (128, 8), (3, 256)]:
+            with self.subTest(shape=shape):
+                x = torch.randn(*shape, device=self.DEVICE)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_all_perms(self):
+        """Test non-identity 3D permutations with distinct dim sizes."""
+        all_perms = [
+            (1, 0, 2),
+            (0, 2, 1),
+            (2, 1, 0),  # full-rank detection
+        ]
+        x = torch.randn(2, 1152, 2048, device=self.DEVICE)
+
+        for perm in all_perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    @skip_if_tpu
+    def test_permute_contiguous_3d_collapsed(self):
+        """Test 3D permutations that require collapsed-dim detection."""
+        all_perms = [
+            (2, 0, 1),
+            (1, 2, 0),
+        ]
+        x = torch.randn(2, 1152, 2048, device=self.DEVICE)
+
+        for perm in all_perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_large_notile(self):
+        """Test 3D perms with large shape but no tiling (dim=1024 exact fit)."""
+        perms = [(1, 0, 2), (0, 2, 1), (2, 1, 0)]
+
+        x = torch.randn(2, 1152, 1024, device=self.DEVICE)
+
+        for perm in perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_medium(self):
+        """Test 3D perms with medium shape that triggers tiling on last dim."""
+        perms = [(1, 0, 2), (0, 2, 1), (2, 1, 0)]
+
+        x = torch.randn(2, 8, 2048, device=self.DEVICE)
+
+        for perm in perms:
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    def test_permute_contiguous_3d_small(self):
+        """Test 3D perms with small shapes that produce grid=(1,)."""
+        all_perms = [
+            (1, 0, 2),
+            (0, 2, 1),
+            (2, 1, 0),
+            (2, 0, 1),
+            (1, 2, 0),
+        ]
+        # (2, 0, 1) on (2, 8, 16) triggers a Mosaic "unsupported shape cast"
+        # bug on TPU due to internal tile padding (128x16 -> 16x128).
+        tpu_skip = {(2, 0, 1)} if self.DEVICE == "tpu" else set()
+        x = torch.randn(2, 8, 16, device=self.DEVICE)
+
+        for perm in all_perms:
+            if perm in tpu_skip:
+                continue
+            with self.subTest(perm=perm):
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
+    @skip_if_cuda
+    @skip_if_tpu
+    def test_permute_contiguous_4d(self):
+        """Test all 23 non-identity 4D permutations with multi-tile grids."""
+        all_perms = [
+            (0, 1, 3, 2),
+            (0, 2, 1, 3),
+            (0, 2, 3, 1),
+            (0, 3, 1, 2),
+            (0, 3, 2, 1),
+            (1, 0, 2, 3),
+            (1, 0, 3, 2),
+            (1, 2, 0, 3),
+            (1, 2, 3, 0),
+            (1, 3, 0, 2),
+            (1, 3, 2, 0),
+            (2, 0, 1, 3),
+            (2, 0, 3, 1),
+            (2, 1, 0, 3),
+            (2, 1, 3, 0),
+            (2, 3, 0, 1),
+            (2, 3, 1, 0),
+            (3, 0, 1, 2),
+            (3, 0, 2, 1),
+            (3, 1, 0, 2),
+            (3, 1, 2, 0),
+            (3, 2, 0, 1),
+            (3, 2, 1, 0),
+        ]
+        needs_large_first = {
+            (0, 3, 2, 1),
+            (1, 3, 0, 2),
+            (1, 3, 2, 0),
+            (2, 3, 1, 0),
+            (3, 0, 2, 1),
+            (3, 1, 0, 2),
+            (3, 1, 2, 0),
+            (3, 2, 0, 1),
+            (3, 2, 1, 0),
+        }
+        shapes = {
+            "large_last": (2, 4, 128, 2048),
+            "large_first": (1152, 1152, 2, 4),
+        }
+
+        for perm in all_perms:
+            key = "large_first" if perm in needs_large_first else "large_last"
+            shape = shapes[key]
+            with self.subTest(perm=perm, shape=shape):
+                x = torch.randn(*shape, device=self.DEVICE)
+
+                def fn(x, p=perm):
+                    return x.permute(*p).contiguous()
+
+                compiled = self._compile(fn)
+                result = compiled(x)
+                expected = fn(x)
+                self.assertEqual(result, expected)
+
     def test_warpgroup_size_2d_aligned_32x8(self):
         """Test 2D tensor with 32x8 = 256 elements (2 warpgroups)."""
 