Skip to content

Commit db2b273

Browse files
kurtamohler authored and facebook-github-bot committed
Reland: Fix CUDA device guard usage when first arg of kernel is scalar (#39956)
Summary: Reland PR #39870 Closes #38889 Pull Request resolved: #39956 Differential Revision: D22027956 Pulled By: ngimel fbshipit-source-id: e6029f450e2da3782b2d05bcc2012c19b82291da
1 parent e62d655 commit db2b273

2 files changed

Lines changed: 20 additions & 1 deletion

File tree

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
9090
using arg2_t = typename traits::template arg<1>::type;
9191
auto a = iter.scalar_value<arg1_t>(1);
9292
iter.remove_operand(1);
93+
const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
9394
gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
9495
return f(a, b);
9596
});

test/test_torch.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9547,7 +9547,7 @@ def run_test(input_):
95479547
M, N = input_.shape
95489548
input_.zero_()
95499549
for i in range(min(M, N)):
9550-
input_[i][i] = 1
9550+
input_[i][i] = 1
95519551
output1 = input_.argmax(dim=0)
95529552
output2 = input_.sum(dim=0)
95539553
for i in range(min(M, N)):
@@ -14166,6 +14166,24 @@ def test_cross_device_binary_ops(self, devices):
1416614166
with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
1416714167
op(cpu_tensor, a)
1416814168

14169+
# This test ensures that a scalar Tensor can be safely used
14170+
# in a binary operation in conjunction with a Tensor on all
14171+
# available CUDA devices
14172+
@deviceCountAtLeast(2)
14173+
@onlyCUDA
14174+
def test_binary_op_scalar_device_unspecified(self, devices):
14175+
scalar_val = torch.tensor(1.)
14176+
for default_device in devices:
14177+
with torch.cuda.device(default_device):
14178+
for device in devices:
14179+
device_obj = torch.device(device)
14180+
x = torch.rand(3, device=device)
14181+
y0 = x * scalar_val
14182+
self.assertEqual(y0.device, device_obj)
14183+
y1 = scalar_val * x
14184+
self.assertEqual(y1.device, device_obj)
14185+
self.assertEqual(y0, y1)
14186+
1416914187
# Tests that CPU scalars (including zero dim tensors) can be used in
1417014188
# binary operations with CUDA tensors.
1417114189
@onlyCUDA

0 commit comments

Comments
 (0)