@@ -747,10 +747,13 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo
747747 )
748748
749749 # Replace your privateuse1 backend name with 'privateuse1'
750+ # This handles the case where PrivateUse1TestBase.device_type has been
751+ # changed from "privateuse1" to the actual backend name (e.g., "openreg")
752+ # by setUpClass being called during previous instantiate_device_type_tests calls
750753 if is_privateuse1_backend_available ():
751754 privateuse1_backend_name = torch ._C ._get_privateuse1_backend_name ()
752755
753- def func_replace (x : str ):
756+ def func_replace (x : str ) -> str :
754757 return x .replace (privateuse1_backend_name , "privateuse1" )
755758
756759 except_for = (
@@ -763,14 +766,20 @@ def func_replace(x: str):
763766 if not isinstance (only_for , str )
764767 else func_replace (only_for )
765768 )
769+ else :
770+
771+ def func_replace (x : str ) -> str :
772+ return x
766773
767774 if except_for :
768775 device_type_test_bases = filter (
769- lambda x : x .device_type not in except_for , device_type_test_bases
776+ lambda x : func_replace (x .device_type ) not in except_for ,
777+ device_type_test_bases ,
770778 )
771779 if only_for :
772780 device_type_test_bases = filter (
773- lambda x : x .device_type in only_for , device_type_test_bases
781+ lambda x : func_replace (x .device_type ) in only_for ,
782+ device_type_test_bases ,
774783 )
775784
776785 return list (device_type_test_bases )
0 commit comments