@@ -3288,51 +3288,10 @@ def forward(self, x):
32883288 self .assertEqual (exported .device .index , 0 )
32893289 self .assertEqual (exported .dtype , torch .bfloat16 )
32903290
3291- @unittest .skipIf (not torch .cuda .is_available (), "requires cuda" )
3292- def test_cuda_amp_autocast (self ):
3293- class MyModule (torch .nn .Module ):
3294- def forward (self , x ):
3295- a_float32 = torch .rand ((8 , 8 ), device = "cuda" )
3296- b_float32 = torch .rand ((8 , 8 ), device = "cuda" )
3297-
3298- with torch .cuda .amp .autocast (dtype = torch .torch .float64 ):
3299- c_float64 = torch .mm (a_float32 , b_float32 )
3300- return c_float64
3301-
3302- module = MyModule ()
3303- real = module (torch .tensor ([0.5 ]))
3304- real_device = real .device
3305- real_dtype = real .dtype
3306-
3307- graph , _ = torch ._dynamo .export (module , torch .tensor ([[0.0 , 0 ], [0 , 0 ]]))
3308- exported = graph (torch .tensor ([0.5 ]))
3309- self .assertEqual (exported .device , real_device )
3310- self .assertEqual (exported .dtype , real_dtype )
3311-
3312- self .assertEqual (exported .device .type , "cuda" )
3313- self .assertEqual (exported .device .index , 0 )
3314- self .assertEqual (exported .dtype , torch .float64 )
3315-
3316- def test_is_autocast_cpu_enabled (self ):
3317- def fn (a_float32 , b_float32 ):
3318- with torch .cpu .amp .autocast (dtype = torch .bfloat16 ):
3319- c_float16 = torch .mm (a_float32 , b_float32 )
3320- if torch .is_autocast_cpu_enabled ():
3321- c_float16 = c_float16 + 1
3322- return c_float16
3323-
3324- a = torch .rand ((8 , 8 ))
3325- b = torch .rand ((8 , 8 ))
3326- ref = fn (a , b )
3327- opt_fn = torch ._dynamo .optimize ("eager" , nopython = True )(fn )
3328- res = opt_fn (a , b )
3329- self .assertTrue (same (ref , res ))
3330-
33313291 @unittest .skipIf (
33323292 not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater ,
33333293 "Can't run fused SDPA on this platform" ,
33343294 )
3335- @patch .object (torch ._dynamo .config , "dynamic_shapes" , False )
33363295 def test_autocast_sdpa (self ):
33373296 class MyModule (torch .nn .Module ):
33383297 def forward (self , query , key , value ):
0 commit comments