@@ -1939,6 +1939,41 @@ def forward(self, x, y):
19391939 shifts = torch .arange (0 , 64 , 8 , device = x .device , dtype = torch .int64 )
19401940 return (expanded >> shifts ) & 255
19411941
1942+ torch .cuda .caching_allocator_enable (False )
1943+ model = Repro ()
1944+ example_inputs = (
1945+ torch .randint (
1946+ 0 , 256 , (200 , INNER_DIM ), device = self .device , dtype = torch .int64
1947+ ),
1948+ torch .randn (50 , 8 , device = self .device ),
1949+ )
1950+ spec = {
1951+ "x" : (Dim .DYNAMIC , Dim .STATIC ),
1952+ "y" : (Dim .DYNAMIC , Dim .STATIC ),
1953+ }
1954+ self .check_model (model , example_inputs , dynamic_shapes = spec )
1955+ torch .cuda .caching_allocator_enable (True )
1956+
1957+ @skipIfMPS
1958+ @config .patch ({"triton.autotune_at_compile_time" : None })
1959+ @torch .fx .experimental ._config .patch ("backed_size_oblivious" , True )
1960+ def test_slice_negative_index_backed_symints_no_unbacked (self ):
1961+ # x[-s1:] where x.size(0) = s0-1 should produce Max(s0-1 - s1, 0),
1962+ # not an unbacked symint with a bad fallback value.
1963+ if self .device != GPU_TYPE :
1964+ raise unittest .SkipTest ("requires triton" )
1965+
1966+ INNER_DIM = 4224
1967+
class Repro(torch.nn.Module):
    """Repro module for negative-index slicing with backed symints.

    Trims the last row of ``x``, takes the trailing ``y.size(0)`` rows
    (the negative-index slice under test), then unpacks every int64
    element into its 8 little-endian bytes via shift-and-mask.
    """

    def forward(self, x, y):
        # Drop the last row so the slice below runs over a symbolic
        # length (s0 - 1) rather than the original s0.
        x_trimmed = x[:-1]
        # x_trimmed[-s1:] — per the test comment, this should resolve to
        # Max(s0 - 1 - s1, 0) instead of introducing an unbacked symint.
        sliced = x_trimmed[-y.size(0):]
        # (k, 4224) -> (k, 128, 33); assumes x.size(1) == 128 * 33 == 4224.
        reshaped = sliced.reshape(-1, 128, 33)
        # Broadcast each element across the 8 byte positions.
        expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8)
        # Bit offsets 0, 8, ..., 56 — one per byte of an int64.
        shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64)
        # Shift right by each offset and mask to extract byte i.
        return (expanded >> shifts) & 255
1976+
19421977 torch .cuda .caching_allocator_enable (False )
19431978 try :
19441979 model = Repro ()
0 commit comments