P2P send recv test gives errors #8074

@ajayvohra2005

Description

🐛 Bug

A simple point-to-point test using xm.send and xm.recv fails: the receiving ranks abort with an XLA check failure.

To Reproduce

Steps to reproduce the behavior:

  1. Run the test code below with 8 processes (ranks 0 to 7), one per GPU:
import torch
import torch.distributed
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend as xb  # registers the "xla" backend with torch.distributed


def test_p2p():
    torch.distributed.init_process_group(backend="xla", init_method="xla://")

    rank = torch.distributed.get_rank()
    device = xm.xla_device()
    # Rank r holds [1 + 2r, 2 + 2r].
    tensor = torch.arange(2, dtype=torch.float32, device=device) + 1 + 2 * rank

    # Ranks 0-3 send to ranks 4-7; the maps assume a world size of 8.
    next_map = {0: 4, 1: 5, 2: 6, 3: 7}
    prev_map = {4: 0, 5: 1, 6: 2, 7: 3}

    torch.distributed.barrier()

    # Note: the second argument of xm.send / xm.recv is a channel_id;
    # the peer rank is passed as that id here.
    if rank < 4:
        print(f"send at rank: {rank}, to:{next_map[rank]}, tensor: {tensor}")
        xm.send(tensor, next_map[rank])

    torch.distributed.barrier()

    recv_buffer = torch.zeros_like(tensor)
    if rank >= 4:
        print(f"recv at rank: {rank}, from:{prev_map[rank]} ... ")
        xm.recv(recv_buffer, prev_map[rank])
        # Printing recv_buffer forces the pending graph (including the recv) to
        # compile; the stack traces below show the failure in that compile step.
        print(f"recv at rank: {rank}, from:{prev_map[rank]}, recv_buffer: {recv_buffer}")

    torch.distributed.barrier()

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    test_p2p()
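To help narrow down whether the failure is specific to the Send/Recv lowering (the stack traces further down fail inside xla::gpu::IrEmitterUnnested::EmitRecvThunk), the same data movement can be expressed with xm.collective_permute, which lowers to a CollectivePermute HLO instead of Send/Recv. This is an untested sketch, not a fix: build_permute_pairs and p2p_via_collective_permute are hypothetical helper names, and it assumes one XLA device per process so that global ranks line up with replica ids.

```python
from typing import List


def build_permute_pairs(world_size: int) -> List[List[int]]:
    """Hypothetical helper: pair rank r with r + world_size // 2 for the first half of ranks.

    For world_size == 8 this mirrors next_map in the repro: {0: 4, 1: 5, 2: 6, 3: 7}.
    """
    if world_size % 2 != 0:
        raise ValueError("world_size must be even")
    half = world_size // 2
    return [[src, src + half] for src in range(half)]


def p2p_via_collective_permute(tensor, world_size: int):
    """Untested sketch: move data from the first half of ranks to the second half.

    CollectivePermute is a collective, so every rank must call this. Ranks that
    are not the target of any pair get back a tensor of zeros.
    """
    # Imported lazily so the pairing helper above can be used without an XLA device.
    import torch_xla.core.xla_model as xm

    return xm.collective_permute(tensor, build_permute_pairs(world_size))
```

Unlike the send/recv branches in the repro, every rank would call this, e.g. `recv_buffer = p2p_via_collective_permute(tensor, world_size=8)` followed by `xm.mark_step()`; ranks 0 to 3 would then hold zeros.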

Expected behavior

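The test should run without errors, and ranks 4, 5, 6 and 7 should receive [1., 2.], [3., 4.], [5., 6.] and [7., 8.] from ranks 0, 1, 2 and 3, respectively.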
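The values the receiving ranks should end up with follow directly from how the repro builds its tensors (rank r holds torch.arange(2) + 1 + 2 * r). A small pure-Python sketch, needing no XLA device, that derives them; it assumes the pairing r -> r + world_size // 2 used by next_map, and expected_received is a hypothetical name:

```python
from typing import Dict, List


def expected_received(world_size: int = 8) -> Dict[int, List[float]]:
    """Values each receiving rank should hold after the repro's send/recv.

    Sender r holds [1 + 2r, 2 + 2r]; receiver r + world_size // 2 should end up
    with exactly that payload.
    """
    half = world_size // 2
    return {
        src + half: [float(1 + 2 * src), float(2 + 2 * src)]
        for src in range(half)
    }
```

For the default world size of 8 this gives {4: [1.0, 2.0], 5: [3.0, 4.0], 6: [5.0, 6.0], 7: [7.0, 8.0]}, which matches the tensors printed by the sending ranks in the log.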

Log output showing error

W0930 17:26:04.156000 140373368411968 torch/distributed/run.py:757] 
W0930 17:26:04.156000 140373368411968 torch/distributed/run.py:757] *****************************************
W0930 17:26:04.156000 140373368411968 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0930 17:26:04.156000 140373368411968 torch/distributed/run.py:757] *****************************************
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.852975  184724 coordination_service.cc:365] Initializing CoordinationService
I0000 00:00:1727717166.857299  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:0 has connected to coordination service. Incarnation: 1331949604963884975
I0000 00:00:1727717166.857361  186366 coordination_service.cc:541] Waiting for 7/8 tasks to connect.
I0000 00:00:1727717166.857385  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:2
/job:jax_worker/replica:0/task:4
/job:jax_worker/replica:0/task:1
I0000 00:00:1727717166.857775  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:2 has connected to coordination service. Incarnation: 12293608533655025016
I0000 00:00:1727717166.857797  186366 coordination_service.cc:541] Waiting for 6/8 tasks to connect.
I0000 00:00:1727717166.857801  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:4
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:5
I0000 00:00:1727717166.858749  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:3 has connected to coordination service. Incarnation: 6693673612313623099
I0000 00:00:1727717166.858794  186366 coordination_service.cc:541] Waiting for 5/8 tasks to connect.
I0000 00:00:1727717166.858811  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:4
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:5
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.859230  184726 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717166.859488  184724 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717166.860198  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:5 has connected to coordination service. Incarnation: 14604579324477984403
I0000 00:00:1727717166.860218  186366 coordination_service.cc:541] Waiting for 4/8 tasks to connect.
I0000 00:00:1727717166.860223  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:4
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:6
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.860551  184727 coordination_service_agent.cc:303] Coordination agent has successfully connected.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.861981  184729 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717166.874591  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:6 has connected to coordination service. Incarnation: 1852783379941685202
I0000 00:00:1727717166.874617  186366 coordination_service.cc:541] Waiting for 3/8 tasks to connect.
I0000 00:00:1727717166.874622  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:4
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:7
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.876173  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:4 has connected to coordination service. Incarnation: 3757167233025618777
I0000 00:00:1727717166.876114  184730 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717166.876193  186366 coordination_service.cc:541] Waiting for 2/8 tasks to connect.
I0000 00:00:1727717166.876203  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:1
/job:jax_worker/replica:0/task:7
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717166.878104  184728 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717167.853755  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:1 has connected to coordination service. Incarnation: 2093214038750281383
I0000 00:00:1727717167.853780  186366 coordination_service.cc:541] Waiting for 1/8 tasks to connect.
I0000 00:00:1727717167.853784  186366 coordination_service.cc:544] Example stragglers:
/job:jax_worker/replica:0/task:7
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717167.853903  184725 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717167.854359  186366 coordination_service.cc:579] /job:jax_worker/replica:0/task:7 has connected to coordination service. Incarnation: 17786123015033867819
I0000 00:00:1727717167.854374  186366 coordination_service.cc:541] Waiting for 0/8 tasks to connect.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1727717167.854489  184731 coordination_service_agent.cc:303] Coordination agent has successfully connected.
I0000 00:00:1727717169.854387  186856 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.873099  184727 service.cc:145] XLA service 0x56207f623d10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.873143  184727 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.875642  184727 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.878063  184727 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 3 for BFCAllocator.
I0000 00:00:1727717169.878535  184727 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1727717169.878706  184727 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.907579  186883 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.912426  184724 service.cc:145] XLA service 0x55f9dc3d8670 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.912561  184724 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.914960  184724 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.915052  184724 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 0 for BFCAllocator.
I0000 00:00:1727717169.915148  184724 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1727717169.915343  184724 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.934043  186902 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.940388  184726 service.cc:145] XLA service 0x56185428cba0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.940470  184726 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.941640  184726 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.941711  184726 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 2 for BFCAllocator.
I0000 00:00:1727717169.942866  184726 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1727717169.943044  184726 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.943217  186910 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.946296  186865 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.946400  186892 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.946771  184730 service.cc:145] XLA service 0x55f5ebc2e010 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.946812  184730 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.949711  184730 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.949764  184725 service.cc:145] XLA service 0x55c22b5706d0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.949802  184725 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.952002  184730 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 6 for BFCAllocator.
I0000 00:00:1727717169.953677  184730 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 6 for CollectiveBFCAllocator.
I0000 00:00:1727717169.953861  184730 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.954342  184725 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.954509  184728 service.cc:145] XLA service 0x55e96db73d80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.954565  184728 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.955257  184725 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 1 for BFCAllocator.
I0000 00:00:1727717169.955619  184725 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1727717169.955781  184725 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.956115  184728 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.956204  184728 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 4 for BFCAllocator.
I0000 00:00:1727717169.956527  184728 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 4 for CollectiveBFCAllocator.
I0000 00:00:1727717169.956749  184728 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.965518  186874 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.965584  186919 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.967922  184731 service.cc:145] XLA service 0x5582b5978080 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.967972  184731 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.967999  184729 service.cc:145] XLA service 0x55a094abf3a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1727717169.968059  184729 service.cc:153]   StreamExecutor device (0): NVIDIA A10G, Compute Capability 8.6
I0000 00:00:1727717169.968577  184731 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.968634  184731 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 7 for BFCAllocator.
I0000 00:00:1727717169.968678  184731 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 7 for CollectiveBFCAllocator.
I0000 00:00:1727717169.968749  184729 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1727717169.968807  184729 gpu_helpers.cc:107] XLA backend allocating 17787371520 bytes on device 5 for BFCAllocator.
I0000 00:00:1727717169.968848  184731 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1727717169.968856  184729 gpu_helpers.cc:147] XLA backend will use up to 5929123840 bytes on device 5 for CollectiveBFCAllocator.
I0000 00:00:1727717169.969072  184729 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
recv at rank: 7, from:3 ... 
recv at rank: 6, from:2 ... 
recv at rank: 4, from:0 ... 
recv at rank: 5, from:1 ... 
F0000 00:00:1727717170.470668  184731 shape.h:207] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: F32 dimensions: 2 layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } is_dynamic_dimension: false } tuple_shapes { element_type: U32 layout { tail_padding_alignment_in_elements: 1 } } tuple_shapes { element_type: TOKEN }
F0000 00:00:1727717170.483437  184730 shape.h:207] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: F32 dimensions: 2 layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } is_dynamic_dimension: false } tuple_shapes { element_type: U32 layout { tail_padding_alignment_in_elements: 1 } } tuple_shapes { element_type: TOKEN }
F0000 00:00:1727717170.485669  184728 shape.h:207] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: F32 dimensions: 2 layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } is_dynamic_dimension: false } tuple_shapes { element_type: U32 layout { tail_padding_alignment_in_elements: 1 } } tuple_shapes { element_type: TOKEN }
F0000 00:00:1727717170.488110  184729 shape.h:207] Check failed: has_layout() element_type: TUPLE tuple_shapes { element_type: F32 dimensions: 2 layout { minor_to_major: 0 tail_padding_alignment_in_elements: 1 } is_dynamic_dimension: false } tuple_shapes { element_type: U32 layout { tail_padding_alignment_in_elements: 1 } } tuple_shapes { element_type: TOKEN }
send at rank: 0, to:4, tensor: tensor([1., 2.], device='xla:0')
send at rank: 3, to:7, tensor: tensor([7., 8.], device='xla:0')
send at rank: 2, to:6, tensor: tensor([5., 6.], device='xla:0')
send at rank: 1, to:5, tensor: tensor([3., 4.], device='xla:0')
I0000 00:00:1727717171.099255  184724 coordination_service_agent.cc:472] Coordination agent has initiated Shutdown().
I0000 00:00:1727717171.180823  184727 coordination_service_agent.cc:472] Coordination agent has initiated Shutdown().
I0000 00:00:1727717171.395648  184725 coordination_service_agent.cc:472] Coordination agent has initiated Shutdown().
I0000 00:00:1727717171.431319  184726 coordination_service_agent.cc:472] Coordination agent has initiated Shutdown().
*** Check failure stack trace: ***
    @     0x7f37825c8159  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f37781a358b  xla::Shape::layout()
    @     0x7f37789eb14a  xla::gpu::IrEmitterUnnested::EmitRecvThunk()
    @     0x7f37789f6900  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f37789f9a80  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f37787ddb6e  xla::gpu::CompileModuleToLlvmIr()
    @     0x7f37787c0756  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7f37787c30f5  xla::gpu::GpuCompiler::RunBackend()
    @     0x7f3778783839  xla::Service::BuildExecutable()
    @     0x7f3778771533  xla::LocalService::CompileExecutables()
    @     0x7f37787636b7  xla::LocalClient::Compile()
    @     0x7f377873f44b  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f377871226f  xla::StreamExecutorGpuClient::Compile()
    @     0x7f3778561177  torch_xla::runtime::PjRtComputationClient::Compile()
    @     0x7f3778357623  torch_xla::XLAGraphExecutor::Compile()
    @     0x7f377835953f  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7f3778359b41  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7f37781fc98a  torch_xla::XLATensor::ApplyPendingGraph()
    @     0x7f37782008ed  torch_xla::XLATensor::GetXlaData()
    @     0x7f3778200a5d  torch_xla::XLATensor::ToTensor()
    @     0x7f3778144288  torch_xla::XLANativeFunctions::_to_copy()
    @     0x7f3778392ab5  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f37983d3a65  at::_ops::_to_copy::redispatch()
*** Check failure stack trace: ***
    @     0x7fe2a6c36159  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fe29c81158b  xla::Shape::layout()
    @     0x7fe29d05914a  xla::gpu::IrEmitterUnnested::EmitRecvThunk()
    @     0x7fe29d064900  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7fe29d067a80  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7fe29ce4bb6e  xla::gpu::CompileModuleToLlvmIr()
    @     0x7fe29ce2e756  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7fe29ce310f5  xla::gpu::GpuCompiler::RunBackend()
    @     0x7fe29cdf1839  xla::Service::BuildExecutable()
    @     0x7fe29cddf533  xla::LocalService::CompileExecutables()
    @     0x7fe29cdd16b7  xla::LocalClient::Compile()
    @     0x7fe29cdad44b  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fe29cd8026f  xla::StreamExecutorGpuClient::Compile()
    @     0x7fe29cbcf177  torch_xla::runtime::PjRtComputationClient::Compile()
    @     0x7fe29c9c5623  torch_xla::XLAGraphExecutor::Compile()
    @     0x7fe29c9c753f  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7fe29c9c7b41  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7fe29c86a98a  torch_xla::XLATensor::ApplyPendingGraph()
    @     0x7fe29c86e8ed  torch_xla::XLATensor::GetXlaData()
    @     0x7fe29c86ea5d  torch_xla::XLATensor::ToTensor()
    @     0x7fe29c7b2288  torch_xla::XLANativeFunctions::_to_copy()
    @     0x7fe29ca00ab5  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7fe2bca41a65  at::_ops::_to_copy::redispatch()
*** Check failure stack trace: ***
    @     0x7f4879690159  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f486f26b58b  xla::Shape::layout()
    @     0x7f486fab314a  xla::gpu::IrEmitterUnnested::EmitRecvThunk()
    @     0x7f486fabe900  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f486fac1a80  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f486f8a5b6e  xla::gpu::CompileModuleToLlvmIr()
    @     0x7f486f888756  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7f486f88b0f5  xla::gpu::GpuCompiler::RunBackend()
    @     0x7f486f84b839  xla::Service::BuildExecutable()
    @     0x7f486f839533  xla::LocalService::CompileExecutables()
    @     0x7f486f82b6b7  xla::LocalClient::Compile()
    @     0x7f486f80744b  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f486f7da26f  xla::StreamExecutorGpuClient::Compile()
    @     0x7f486f629177  torch_xla::runtime::PjRtComputationClient::Compile()
    @     0x7f486f41f623  torch_xla::XLAGraphExecutor::Compile()
    @     0x7f486f42153f  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7f486f421b41  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7f486f2c498a  torch_xla::XLATensor::ApplyPendingGraph()
    @     0x7f486f2c88ed  torch_xla::XLATensor::GetXlaData()
    @     0x7f486f2c8a5d  torch_xla::XLATensor::ToTensor()
    @     0x7f486f20c288  torch_xla::XLANativeFunctions::_to_copy()
    @     0x7f486f45aab5  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f488f49ba65  at::_ops::_to_copy::redispatch()
*** Check failure stack trace: ***
    @     0x7fd05d889159  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fd05346458b  xla::Shape::layout()
    @     0x7fd053cac14a  xla::gpu::IrEmitterUnnested::EmitRecvThunk()
    @     0x7fd053cb7900  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7fd053cbaa80  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7fd053a9eb6e  xla::gpu::CompileModuleToLlvmIr()
    @     0x7fd053a81756  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7fd053a840f5  xla::gpu::GpuCompiler::RunBackend()
    @     0x7fd053a44839  xla::Service::BuildExecutable()
    @     0x7fd053a32533  xla::LocalService::CompileExecutables()
    @     0x7fd053a246b7  xla::LocalClient::Compile()
    @     0x7fd053a0044b  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fd0539d326f  xla::StreamExecutorGpuClient::Compile()
    @     0x7fd053822177  torch_xla::runtime::PjRtComputationClient::Compile()
    @     0x7fd053618623  torch_xla::XLAGraphExecutor::Compile()
    @     0x7fd05361a53f  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7fd05361ab41  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7fd0534bd98a  torch_xla::XLATensor::ApplyPendingGraph()
    @     0x7fd0534c18ed  torch_xla::XLATensor::GetXlaData()
    @     0x7fd0534c1a5d  torch_xla::XLATensor::ToTensor()
    @     0x7fd053405288  torch_xla::XLANativeFunctions::_to_copy()
    @     0x7fd053653ab5  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7fd073694a65  at::_ops::_to_copy::redispatch()

Environment

Docker image

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1

Nvidia GPUs

nvidia-smi
Wed Sep 25 22:47:32 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:16.0 Off |                    0 |
|  0%   31C    P0             61W /  300W |    1356MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:17.0 Off |                    0 |
|  0%   27C    P8             15W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:18.0 Off |                    0 |
|  0%   27C    P8             16W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:19.0 Off |                    0 |
|  0%   26C    P8             15W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A10G                    Off |   00000000:00:1A.0 Off |                    0 |
|  0%   26C    P8             16W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   26C    P8             16W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   26C    P8             15W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   26C    P8             16W /  300W |     120MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

OS

6.2.0-1017-aws #17~22.04.1-Ubuntu SMP Fri Nov 17 21:07:13 UTC 2023 x86_64 GNU/Linux

Labels

distributed (SPMD and other distributed things), usability (bugs/features related to improving the usability of PyTorch/XLA)