Skip to content

A large number of tensors (>8000) in the graph triggers an SPMD sharding error #7161

@mars1248

Description

@mars1248

🐛 Bug

To Reproduce

Steps to reproduce the behavior:
test.sh

# Launcher for the reproduction: sets the XLA / PJRT environment and then runs
# test_multi_param_layer.py.

# Start every run from an empty HLO dump directory.
rm -rf ./hlo_logs

# XLA compiler flags. HLO modules are dumped to ./hlo_logs.
# Minimal alternative: export XLA_FLAGS="--xla_dump_to=./hlo_logs"
export XLA_FLAGS="--xla_dump_to=./hlo_logs \
    --xla_gpu_enable_analytical_latency_estimator=true \
    --xla_cpu_enable_fast_math=false \
    --xla_gpu_simplify_all_fp_conversions=false \
    --xla_gpu_force_compilation_parallelism=64  \
    --xla_gpu_enable_pipelined_collectives=true \
    --xla_gpu_enable_pipelined_all_reduce=true \
    --xla_gpu_enable_async_collectives=true \
    --xla_disable_hlo_passes=post-scheduling-passes,gpu-schedule-postprocessing \
    --xla_gpu_enable_triton_gemm=false \
"

# PJRT allocator settings.
export PJRT_ALLOCATOR_PREALLOCATE=false
export PJRT_ALLOCATOR_FRACTION=0.75
export PJRT_ALLOCATOR_CUDA_ASYNC=false

# PyTorch/XLA debug settings.
export PT_XLA_DEBUG=1
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export XLA_SAVE_TENSORS_FILE=debug.txt

# Optional verbose C++ logging; uncomment both lines to enable.
#export TF_CPP_MIN_LOG_LEVEL=0
#export TF_CPP_VMODULE="lazy_graph_executor=4,xla_graph_executor=5,nccl_collective_thunk=5,gpu_executable=5,gpu_compiler=5,service=5,collectives=5,xla_graph_executor=5,pjrt_computation_client=5,pjrt_stream_executor_client=5"

# Device selection. One export per line: a trailing backslash here would
# splice the following line into the same command.
export GPU_NUM_DEVICES=8
export PJRT_DEVICE=CUDA
#export CUDA_VISIBLE_DEVICES=1,2,3,4

#python test_activation_local.py
python test_multi_param_layer.py

test_multi_param_layer.py

from typing import Dict, List, Optional, Tuple, Union
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.amp import autocast, GradScaler
import numpy as np
import torch.optim as optim
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
import transformers
import torch_xla.debug.profiler as xp
from torch_xla.amp import syncfree
import time
import os
import sys
import math
from torch import nn
from torch.nn import Linear
from torch.autograd import Function
import torch.nn.functional as F
from torch.optim.adamw import AdamW

# Switch the runtime to SPMD mode, then grab the XLA device.
# (Non-XLA alternative that was tried: device = "cuda")
xr.use_spmd()
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()

# 2-D device mesh of shape (num_devices, 1): every device lies along the
# 'fsdp' axis and the 'replica' axis has size 1 -- e.g. (8, 1) when eight
# devices are visible.
mesh_shape = (num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'replica'))
class MyModel(torch.nn.Module):
    """Stack of identical square ``nn.Linear`` layers applied one after another.

    Args:
        num_layers: Number of linear layers in the stack. Defaults to 400,
            the value used by the original reproduction; pass a smaller
            number (e.g. 10) to compare behaviour with fewer parameters.
        features: Input and output width of every layer. Defaults to 8.
    """

    def __init__(self, num_layers: int = 400, features: int = 8):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [nn.Linear(features, features) for _ in range(num_layers)])

    def forward(self, xt1: torch.Tensor) -> torch.Tensor:
        """Feed ``xt1`` through every layer in order and return the result."""
        for layer in self.layers:
            xt1 = layer(xt1)
        return xt1

# Build the model on the XLA device, wrap it with SPMD FSDP over the mesh, and
# attach the sync-free AdamW optimizer.
my_model = MyModel().to(device)
my_model = FSDPv2(my_model, mesh)
optimizer = syncfree.AdamW(my_model.parameters(), lr=0.01)

# Host-side input batch; it is uploaded to the device again on every step.
t1 = torch.randn(8, 8)

# Shard the input along its first dimension over the 'fsdp' mesh axis and
# leave the remaining dimensions unsharded. The spec is identical for every
# step, so it is built once outside the loop.
partition_spec = ["fsdp"] + [None] * (len(t1.shape) - 1)
spec = xs.ShardingSpec(mesh, partition_spec)

# Two steps are run; the report observes the problem on the second compilation.
for _ in range(2):
    optimizer.zero_grad()
    xt1 = xm.send_cpu_data_to_device(t1, device, input_sharding=spec)[0]
    with autocast(device, dtype=torch.bfloat16):
        loss = my_model(xt1).sum()
    loss.backward()
    # The sync-free optimizer step takes a `found_inf` tensor; here it is
    # derived from a NaN check on the loss (1.0 if NaN, else 0.0).
    found_inf = torch.isnan(loss).to(torch.float32)
    optimizer.step(found_inf=found_inf)
    xm.mark_step()
  1. sh test.sh

Expected behavior

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: latest

In the preceding example, if you change the number of linear layers to 10, it works, but if you change it to 400, you get an error. I observed that on the second compilation all the input tensors were packed into a single tuple, and the sharding information was lost after the post-compile optimization.

Additional context

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions