Skip to content

Commit def4b9b

Browse files
committed
Test list of tensors on different devices
1 parent 7bd4b2c commit def4b9b

1 file changed

Lines changed: 98 additions & 1 deletion

File tree

test/test_foreach.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import itertools
12
import torch
23
import unittest
34
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
45
from torch.testing._internal.common_device_type import \
5-
(instantiate_device_type_tests, dtypes, skipCUDAIfRocm, skipMeta, ops)
6+
(instantiate_device_type_tests, deviceCountAtLeast, dtypes, skipCUDAIfRocm, skipMeta, ops)
67
from torch._six import inf, nan
78
from torch.testing._internal.common_methods_invocations import foreach_unary_op_db
89

@@ -842,6 +843,102 @@ def test_add_list_slow_path(self, device, dtype):
842843
torch._foreach_add_([tensor1], [tensor2])
843844
self.assertEqual(res, [tensor1])
844845

846+
# The three tests below check whether foreach ops work
847+
# when tensors are on different devices
848+
# but tensors with the same index are on the same device.
849+
@skipMeta
@deviceCountAtLeast(2)
@ops(foreach_unary_op_db)
def test_unary_op_tensors_on_different_devices(self, devices, dtype, op):
    # Run a foreach unary op on a single list whose tensors live on two
    # different devices and compare it with the per-tensor reference op,
    # first out-of-place and then (where exercised) in-place.
    is_abs = 'abs' in op.ref.__name__
    if is_abs and dtype == torch.bool:
        return
    for device_pair in itertools.combinations(['cpu'] + devices, 2):
        # float16 is not exercised for pairs that include the CPU.
        if dtype == torch.float16 and 'cpu' in device_pair:
            continue
        # Resulting placement of `inputs`:
        # [device_pair[0], device_pair[0], device_pair[1]]
        inputs = self._get_test_data(device_pair[0], dtype, 3)
        inputs = [
            tensor if idx == 0 else tensor.to(device_pair[idx - 1])
            for idx, tensor in enumerate(inputs)
        ]
        reference = [op.ref(tensor) for tensor in inputs]
        out_of_place = op.get_method()(inputs)
        self.assertEqual(reference, out_of_place)
        # The in-place variant is not checked for abs on complex dtypes
        # nor for any op on integer dtypes.
        if is_abs and dtype in torch.testing.get_all_complex_dtypes():
            continue
        if dtype in torch.testing.get_all_int_dtypes():
            continue
        op.get_inplace()(inputs)
        self.assertEqual(reference, inputs)
869+
870+
# Test whether foreach ops work
871+
# when tensors are on different devices
872+
# but tensors with the same index are on the same device.
873+
@skipMeta
@deviceCountAtLeast(2)
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
def test_bin_op_tensors_on_different_devices(self, devices, dtype):
    # Run every foreach binary op in `self.bin_ops` on two tensor lists whose
    # elements live on different devices (elements at the same index share a
    # device) and compare with the native op applied pair by pair, first
    # out-of-place and then in-place.
    int_dtypes = torch.testing.get_all_int_dtypes()
    for foreach_op, foreach_op_, native_op in self.bin_ops:
        op_name = native_op.__name__
        if dtype == torch.bool and "sub" in op_name:
            # Subtraction does not work with bool. Use `continue` rather than
            # `self.skipTest`: skipTest raises and aborts the whole test, so
            # any ops listed after subtraction would never run for bool.
            continue
        # In-place division is not checked for integer dtypes. bool is
        # excluded as well now that division is reachable for bool.
        # NOTE(review): assumes true division promotes bool to a floating
        # dtype that cannot be written back in place -- confirm.
        skip_inplace = "div" in op_name and (dtype == torch.bool or dtype in int_dtypes)
        for dev0, dev1 in itertools.combinations(['cpu'] + devices, 2):
            tensors1 = self._get_test_data(dev0, dtype, 2)
            tensors2 = self._get_test_data(dev1, dtype, 2)
            # Exchange one element so that both lists are laid out as
            # [dev0, dev1].
            tensors1[1], tensors2[0] = tensors2[0], tensors1[1]

            expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
            actual = foreach_op(tensors1, tensors2)
            self.assertEqual(expected, actual)
            if not skip_inplace:
                foreach_op_(tensors1, tensors2)
                self.assertEqual(expected, tensors1)
896+
897+
# Test whether foreach ops work
898+
# when tensors are on different devices
899+
# but tensors with the same index are on the same device.
900+
@skipMeta
@deviceCountAtLeast(3)
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=True))
def test_pointwise_op_tensors_on_different_devices(self, devices, dtype):
    # Run foreach addcmul/addcdiv on three tensor lists whose elements live on
    # three different devices (elements at the same index share a device) and
    # compare with the native op applied triple by triple, first out-of-place
    # and then in-place.
    if dtype == torch.bool:
        # bool is not exercised by this test for any op or device combination.
        return
    pointwise_ops = [
        (torch._foreach_addcmul, torch._foreach_addcmul_, torch.addcmul),
        (torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv),
    ]
    # These properties depend only on `dtype`, so compute them once instead of
    # once per device combination.
    is_int_dtype = dtype in torch.testing.get_all_int_dtypes()
    is_half_dtype = dtype in (torch.float16, torch.bfloat16)
    for foreach_op, foreach_op_, native_op in pointwise_ops:
        # addcdiv is not exercised for integer dtypes.
        if is_int_dtype and "addcdiv" in native_op.__name__:
            continue
        for dev_list in itertools.combinations(['cpu'] + devices, 3):
            # float16/bfloat16 are not exercised when the CPU is involved.
            if is_half_dtype and 'cpu' in dev_list:
                continue

            # Build one list of three tensors per device, then transpose so
            # that every argument list is laid out as [dev0, dev1, dev2]:
            # tensors1: [dev0, dev1, dev2]
            # tensors2: [dev0, dev1, dev2]
            # tensors3: [dev0, dev1, dev2]
            per_device = [self._get_test_data(dev, dtype, 3) for dev in dev_list]
            tensors1, tensors2, tensors3 = (list(group) for group in zip(*per_device))

            expected = [native_op(t1, t2, t3) for t1, t2, t3 in zip(tensors1, tensors2, tensors3)]
            actual = foreach_op(tensors1, tensors2, tensors3)
            self.assertEqual(expected, actual)
            foreach_op_(tensors1, tensors2, tensors3)
            self.assertEqual(expected, tensors1)
940+
941+
845942
instantiate_device_type_tests(TestForeach, globals())
846943

847944
if __name__ == '__main__':

0 commit comments

Comments
 (0)