Skip to content

Commit 98aef5c

Browse files
committed
more
1 parent 1e37e5b commit 98aef5c

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

torch/multiprocessing/reductions.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def rebuild_cuda_tensor(
169169
event_handle,
170170
event_sync_required,
171171
):
172+
storage_device = _device_from_uuid(storage_device)
172173
# If storage_handle is None, storage points to nullptr.
173174
if storage_handle is None or storage_size_bytes == 0:
174175
storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
@@ -365,7 +366,7 @@ def reduce_tensor(tensor):
365366
tensor_offset, # tensor offset in its storage
366367
type(storage),
367368
tensor.dtype,
368-
device,
369+
_device_to_uuid(device),
369370
handle, # identifier which CUDA allocation is the storage in.
370371
storage_size_bytes, # size(in bytes) of the storage
371372
storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation
@@ -645,3 +646,14 @@ def init_reductions():
645646
from torch.nn.parameter import Parameter
646647

647648
reduction.register(Parameter, reduce_tensor)
649+
650+
651+
def _device_to_uuid(device):
    """Return the UUID of the CUDA device ``device`` as a string.

    The UUID is read from ``torch.cuda.get_device_properties(device).uuid``
    and converted with ``str`` so it can be compared against the string
    form produced elsewhere.
    """
    properties = torch.cuda.get_device_properties(device)
    return str(properties.uuid)
653+
654+
655+
def _device_from_uuid(device_uuid):
656+
for device in range(torch.cuda.device_count()):
657+
if str(torch.cuda.get_device_properties(device).uuid) == device_uuid:
658+
return device
659+
raise Exception("Invalid device_uuid=" + device_uuid)

0 commit comments

Comments
 (0)