[Bug] Multi-node distributed training broken only in one-GPU-per-node setting (one-line fix) #4142

@Maxusmusti

Issue initially discovered by @Fiona-Waters

When running multi-node distributed training where each node has a single GPU (also reproducible by setting CUDA_VISIBLE_DEVICES=0 on each node), Unsloth incorrectly disables distributed training by patching accelerate.state.PartialState._prepare_backend to return DistributedType.NO.

This happens because the check in unsloth/models/_utils.py only looks at DEVICE_COUNT, which comes from torch.cuda.device_count() and therefore reflects only the number of locally visible devices. It does not account for multi-node setups where each node has only 1 visible GPU but WORLD_SIZE > 1.
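
The mismatch can be illustrated by comparing the local device count with the environment variables that torchrun exports to each worker. This is a minimal sketch, assuming the job is launched with torchrun (which sets WORLD_SIZE, LOCAL_WORLD_SIZE, RANK, and LOCAL_RANK). `describe_launch` is a hypothetical helper, not part of Unsloth or accelerate, and the device count is passed in as an argument so the snippet runs without a GPU:

```python
import os


def describe_launch(local_device_count, env=None):
    """Summarize a launch from the local device count and torchrun-style env vars.

    Hypothetical helper for illustration only; not part of Unsloth or accelerate.
    """
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", "1"))
    local_world_size = int(env.get("LOCAL_WORLD_SIZE", "1"))
    return {
        "local_device_count": local_device_count,
        "world_size": world_size,
        "local_world_size": local_world_size,
        # One visible GPU locally, but more than one process in the whole job:
        # the case that a DEVICE_COUNT-only check misclassifies.
        "distributed_with_one_local_gpu": local_device_count == 1 and world_size > 1,
    }


# Environment as a worker would see it with 2 nodes and 1 process per node.
info = describe_launch(
    1,
    {"WORLD_SIZE": "2", "LOCAL_WORLD_SIZE": "1", "RANK": "0", "LOCAL_RANK": "0"},
)
print(info)
```

In a real run the first argument would be torch.cuda.device_count(); a True value for `distributed_with_one_local_gpu` marks exactly the situation in which the current check disables distributed training.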

Steps to Reproduce

  1. Set up a 2-node distributed training run with 1 GPU per node, or set CUDA_VISIBLE_DEVICES=0 on each node.
  2. Launch with torchrun on each node:
  # Node 0
  CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 2 --nproc-per-node 1 --node-rank 0 \
      --rdzv-id 117 --rdzv-endpoint 10.241.128.23:5151 train.py

  # Node 1
  CUDA_VISIBLE_DEVICES=0 torchrun --nnodes 2 --nproc-per-node 1 --node-rank 1 \
      --rdzv-id 117 --rdzv-endpoint 10.241.128.23:5151 train.py

Expected Behavior

Distributed training initializes normally across the 2 nodes.

Actual Behavior

Unsloth sees DEVICE_COUNT == 1 and patches out _prepare_backend, forcing DistributedType.NO. Distributed communication fails to initialize.

Root Cause

In https://github.com/unslothai/unsloth/blob/main/unsloth/models/_utils.py#L1476:

  if DEVICE_COUNT == 1:
      from accelerate.utils.dataclasses import DistributedType
      def _prepare_backend(self, *args, **kwargs):
          return None, DistributedType.NO
      import accelerate.state
      accelerate.state.PartialState._prepare_backend = _prepare_backend

DEVICE_COUNT is set from torch.cuda.device_count() in https://github.com/unslothai/unsloth/blob/main/unsloth/device_type.py, which only counts locally visible GPUs. There is no check for WORLD_SIZE or other torchrun environment variables that indicate a multi-node setup.

Suggested Fix

Add a WORLD_SIZE check so the patch is only applied when we are genuinely on a single device in a single-node setup:

  import os
  world_size = int(os.environ.get("WORLD_SIZE", "1"))
  if DEVICE_COUNT == 1 and world_size <= 1:
      ...
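
The proposed condition can also be factored into a small predicate so it is easy to test in isolation. This is only a sketch: `should_force_single_process` is a hypothetical name, not an existing Unsloth function, and returning False on a malformed WORLD_SIZE is an assumption of this example rather than part of the suggested fix.

```python
import os


def should_force_single_process(device_count, env=None):
    """Return True only if one local device is visible and WORLD_SIZE <= 1.

    Hypothetical helper; mirrors `DEVICE_COUNT == 1 and world_size <= 1`.
    """
    env = os.environ if env is None else env
    try:
        world_size = int(env.get("WORLD_SIZE", "1"))
    except ValueError:
        # Assumption: on a malformed value, leave distributed setup untouched.
        return False
    return device_count == 1 and world_size <= 1


# Single GPU, no launcher environment: patching is safe.
print(should_force_single_process(1, {}))
# Single visible GPU per node, 2-node torchrun job: do not patch.
print(should_force_single_process(1, {"WORLD_SIZE": "2"}))
```

The DistributedType.NO patch would then be applied only when this predicate is true for the current DEVICE_COUNT.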
