@@ -169,6 +169,7 @@ def rebuild_cuda_tensor(
169169 event_handle ,
170170 event_sync_required ,
171171):
172+ storage_device = _device_from_uuid (storage_device )
172173 # If storage_handle is None, storage points to nullptr.
173174 if storage_handle is None or storage_size_bytes == 0 :
174175 storage = storage_cls (0 , dtype = dtype , device = storage_device , _internal = True )
@@ -365,7 +366,7 @@ def reduce_tensor(tensor):
365366 tensor_offset , # tensor offset in its storage
366367 type (storage ),
367368 tensor .dtype ,
368- device ,
369+ _device_to_uuid ( device ) ,
369370 handle , # identifier which CUDA allocation is the storage in.
370371 storage_size_bytes , # size(in bytes) of the storage
371372 storage_offset_bytes , # offset(in bytes) of the storage in the CUDA allocation
@@ -645,3 +646,14 @@ def init_reductions():
645646 from torch .nn .parameter import Parameter
646647
647648 reduction .register (Parameter , reduce_tensor )
649+
650+
def _device_to_uuid(device):
    """Return the UUID of CUDA device ``device`` as a string.

    ``device`` is forwarded unchanged to ``torch.cuda.get_device_properties``;
    the ``uuid`` attribute of the returned properties object is converted
    with ``str``.

    NOTE(review): presumably the UUID is used instead of the device index so
    that a receiving process with a different device numbering can still
    locate the same physical GPU -- confirm against the callers.
    """
    properties = torch.cuda.get_device_properties(device)
    return str(properties.uuid)
653+
654+
655+ def _device_from_uuid (device_uuid ):
656+ for device in range (torch .cuda .device_count ()):
657+ if str (torch .cuda .get_device_properties (device ).uuid ) == device_uuid :
658+ return device
659+ raise Exception ("Invalid device_uuid=" + device_uuid )
0 commit comments