|
| 1 | +import itertools |
1 | 2 | import torch |
2 | 3 | import unittest |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW |
4 | 5 | from torch.testing._internal.common_device_type import \ |
5 | | - (instantiate_device_type_tests, dtypes, skipCUDAIfRocm, skipMeta, ops) |
| 6 | + (instantiate_device_type_tests, deviceCountAtLeast, dtypes, skipCUDAIfRocm, skipMeta, ops) |
6 | 7 | from torch._six import inf, nan |
7 | 8 | from torch.testing._internal.common_methods_invocations import foreach_unary_op_db |
8 | 9 |
|
@@ -842,6 +843,102 @@ def test_add_list_slow_path(self, device, dtype): |
842 | 843 | torch._foreach_add_([tensor1], [tensor2]) |
843 | 844 | self.assertEqual(res, [tensor1]) |
844 | 845 |
|
| 846 | + # The three tests below check that foreach ops work |
| 847 | + # when the tensors in a list are on different devices, |
| 848 | + # as long as tensors with the same index are on the same device. |
@skipMeta
@deviceCountAtLeast(2)
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, devices, dtype, op):
    """Run a foreach unary op on a tensor list that spans two devices.

    For every pair drawn from ``['cpu'] + devices`` the three input tensors
    are placed as ``[pair[0], pair[0], pair[1]]``. The out-of-place result is
    compared with ``op.ref`` applied tensor by tensor, and the in-place
    variant is compared against the same expectation when it is exercised.
    """
    is_abs = 'abs' in op.ref.__name__
    # abs is not exercised for bool inputs.
    if is_abs and dtype == torch.bool:
        return

    for dev_pair in itertools.combinations(['cpu'] + devices, 2):
        # float16 is not exercised when one of the two devices is the CPU.
        if dtype == torch.float16 and 'cpu' in dev_pair:
            continue

        # Tensor 0 stays where it was created; tensor k (k > 0) is sent to
        # dev_pair[k - 1], giving [dev_pair[0], dev_pair[0], dev_pair[1]].
        tensors = []
        for position, tensor in enumerate(self._get_test_data(dev_pair[0], dtype, 3)):
            if position > 0:
                tensor = tensor.to(dev_pair[position - 1])
            tensors.append(tensor)

        expected = [op.ref(t) for t in tensors]
        actual = op.get_method()(tensors)
        self.assertEqual(expected, actual)

        # The in-place variant is not run for abs on complex dtypes, nor for
        # any op on integer dtypes.
        if is_abs and dtype in torch.testing.get_all_complex_dtypes():
            continue
        if dtype in torch.testing.get_all_int_dtypes():
            continue
        op.get_inplace()(tensors)
        self.assertEqual(expected, tensors)
| 869 | + |
| 870 | + # Test that foreach ops work |
| 871 | + # when tensors are on different devices |
| 872 | + # but tensors with the same index are on the same device. |
@skipMeta
@deviceCountAtLeast(2)
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
def test_bin_op_tensors_on_different_devices(self, devices, dtype):
    """Run each foreach binary op on operand lists that span two devices.

    For every pair ``(dev0, dev1)`` drawn from ``['cpu'] + devices`` both
    operand lists are laid out as ``[dev0, dev1]``. The out-of-place result
    is compared with the native op applied pair by pair, and the in-place
    variant is compared against the same expectation when it is exercised.
    """
    for foreach_op, foreach_op_, native_op in self.bin_ops:
        if dtype == torch.bool and "sub" in native_op.__name__:
            # NOTE(review): skipTest ends the whole test here, so any op that
            # follows subtraction in self.bin_ops is never run for bool.
            # Confirm this is intended rather than a `continue`.
            self.skipTest("Subtraction does not work with bool")

        for dev0, dev1 in itertools.combinations(['cpu'] + devices, 2):
            tensors1 = self._get_test_data(dev0, dtype, 2)
            tensors2 = self._get_test_data(dev1, dtype, 2)
            # Exchange one element between the lists so that each of them
            # ends up laid out as [dev0, dev1].
            tensors1[1], tensors2[0] = tensors2[0], tensors1[1]

            expected = [native_op(lhs, rhs) for lhs, rhs in zip(tensors1, tensors2)]
            actual = foreach_op(tensors1, tensors2)
            self.assertEqual(expected, actual)

            # In-place division is not run for integer dtypes.
            if dtype not in torch.testing.get_all_int_dtypes() or "div" not in native_op.__name__:
                foreach_op_(tensors1, tensors2)
                self.assertEqual(expected, tensors1)
| 896 | + |
| 897 | + # Test that foreach ops work |
| 898 | + # when tensors are on different devices |
| 899 | + # but tensors with the same index are on the same device. |
@skipMeta
@deviceCountAtLeast(3)
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
def test_pointwise_op_tensors_on_different_devices(self, devices, dtype):
    """Run foreach addcmul/addcdiv on operand lists that span three devices.

    For every triple ``(dev0, dev1, dev2)`` drawn from ``['cpu'] + devices``
    all three operand lists are laid out as ``[dev0, dev1, dev2]``. The
    out-of-place result is compared with the native op applied element by
    element, and the in-place variant is compared against the same
    expectation.
    """
    # Nothing is exercised for bool, so leave before building any tensors.
    if dtype == torch.bool:
        return

    # These depend only on dtype, so compute them once rather than once per
    # device combination.
    is_integral = dtype in torch.testing.get_all_int_dtypes()
    is_half_like = dtype in (torch.float16, torch.bfloat16)

    pointwise_ops = [
        (torch._foreach_addcmul, torch._foreach_addcmul_, torch.addcmul),
        (torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv),
    ]
    for foreach_op, foreach_op_, native_op in pointwise_ops:
        # addcdiv is not exercised for integer dtypes.
        if is_integral and "addcdiv" in native_op.__name__:
            continue

        for dev_list in itertools.combinations(['cpu'] + devices, 3):
            # float16/bfloat16 are not exercised when the CPU takes part.
            if is_half_like and 'cpu' in dev_list:
                continue

            # Build three tensors on each device, then transpose: operand
            # list k takes the k-th tensor from every device, so each list
            # is laid out as [dev0, dev1, dev2].
            per_device = [self._get_test_data(dev, dtype, 3) for dev in dev_list]
            tensors1, tensors2, tensors3 = (list(column) for column in zip(*per_device))

            expected = [native_op(t1, t2, t3) for t1, t2, t3 in zip(tensors1, tensors2, tensors3)]
            actual = foreach_op(tensors1, tensors2, tensors3)
            self.assertEqual(expected, actual)
            foreach_op_(tensors1, tensors2, tensors3)
            self.assertEqual(expected, tensors1)
| 940 | + |
| 941 | + |
845 | 942 | instantiate_device_type_tests(TestForeach, globals()) |
846 | 943 |
|
847 | 944 | if __name__ == '__main__': |
|
0 commit comments