Skip to content

Commit 47acb06

Browse files
authored
Merge 4c2f399 into 8154159
2 parents 8154159 + 4c2f399 commit 47acb06

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

torch/testing/_internal/common_device_type.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,10 +747,13 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo
747747
)
748748

749749
# Replace your privateuse1 backend name with 'privateuse1'
750+
# This handles the case where PrivateUse1TestBase.device_type has been
751+
# changed from "privateuse1" to the actual backend name (e.g., "openreg")
752+
# by setUpClass being called during previous instantiate_device_type_tests calls
750753
if is_privateuse1_backend_available():
751754
privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
752755

753-
def func_replace(x: str):
756+
def func_replace(x: str) -> str:
754757
return x.replace(privateuse1_backend_name, "privateuse1")
755758

756759
except_for = (
@@ -763,14 +766,20 @@ def func_replace(x: str):
763766
if not isinstance(only_for, str)
764767
else func_replace(only_for)
765768
)
769+
else:
770+
771+
def func_replace(x: str) -> str:
772+
return x
766773

767774
if except_for:
768775
device_type_test_bases = filter(
769-
lambda x: x.device_type not in except_for, device_type_test_bases
776+
lambda x: func_replace(x.device_type) not in except_for,
777+
device_type_test_bases,
770778
)
771779
if only_for:
772780
device_type_test_bases = filter(
773-
lambda x: x.device_type in only_for, device_type_test_bases
781+
lambda x: func_replace(x.device_type) in only_for,
782+
device_type_test_bases,
774783
)
775784

776785
return list(device_type_test_bases)

0 commit comments

Comments
 (0)