
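Distributed training at scale with PyTorch and Ray Train#

Author: Ricardo Decal

This tutorial shows how to distribute PyTorch training across multiple GPUs using Ray Train and Ray Data for scalable, production-ready model training.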

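You will learn how to:
  • Pre-train a GPT-2 (~124M-parameter) language model using PyTorch and Hugging Face Transformers.

  • Distribute training across multiple GPUs with Ray Train with minimal code changes.

  • Stream training data from Hugging Face datasets with Ray Data's distributed workers.

  • Save and load distributed checkpoints.

  • Scale from a single node to a multi-node cluster with minimal code changes.

  • Optimize cost and performance with heterogeneous clusters.

  • Monitor training with the Ray dashboard.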

Prerequisites
  • PyTorch v2.9+.

  • Ray Train (ray[train]) v2.52.1+.

  • tiktoken, datasets, and transformers (Hugging Face).

  • One or more GPUs are recommended but not required. This tutorial was tested on a g4dn.12xlarge instance, which has 4 NVIDIA T4 GPUs (16 GB of memory per GPU).

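Ray Train is a scalable framework for distributed deep learning. Ray Train builds on top of Ray, a unified framework for scaling AI and Python applications that simplifies the complexities of distributed computing. Ray is also open source and part of the PyTorch Foundation.

Ray Train enables you to scale from a single GPU to hundreds of GPUs without rewriting your training loop. Combined with Ray Data for streaming data ingestion, you get an end-to-end distributed training pipeline that handles data loading, sharding, gradient synchronization, checkpointing, and fault tolerance.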

Setup#

To install the dependencies, run pip install "ray[train]" torch tiktoken datasets transformers

Then, import the required libraries:

import os
import tempfile
import time

import numpy as np
import ray
import ray.train
import tiktoken
import torch
from datasets import load_dataset
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from transformers import GPT2Config, GPT2LMHeadModel

# Enable smoke test to run this tutorial quickly.
SMOKE_TEST = True

# Reduce Ray Data verbosity
ray.data.DataContext.get_current().enable_progress_bars = False
ray.data.DataContext.get_current().print_on_execution_start = False

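Loading the dataset with Ray Data#

This tutorial uses the Wikitext-103 dataset, a collection of over 100 million tokens drawn from verified Good and Featured articles on Wikipedia.

The ray.data.from_huggingface() function converts a Hugging Face dataset into a Ray Dataset, enabling distributed streaming and preprocessing across all available nodes.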

hf_ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
train_ds = ray.data.from_huggingface(hf_ds["train"])
val_ds = ray.data.from_huggingface(hf_ds["validation"])

# Limit dataset size for fast iteration during smoke tests.
if SMOKE_TEST:
    train_ds = train_ds.limit(2500)
    val_ds = val_ds.limit(2500)

print(f"Dataset schema:\n{train_ds.schema()}")
2026-06-03 00:27:10,188 WARNING services.py:2213 -- WARNING: The object store is using /tmp/ray instead of /dev/shm because /dev/shm has only 2147467264 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2026-06-03 00:27:12,386 INFO worker.py:2003 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2051: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
  warnings.warn(
Dataset schema:
Column  Type
------  ----
text    string

The schema looks like this:

Column  Type
------  ----
text    string

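This means the dataset has a single column named text, of type string.

Inspecting the raw data#

Use take(n) to fetch a small number of rows for inspection. Each row is a dict keyed by column name.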

print("--- Raw data sample ---")
sample = train_ds.take(2)
for i, row in enumerate(sample):
    text_preview = (row["text"][:120] + "...") if len(row["text"]) > 120 else row["text"]
    print(f"  Row {i}: {text_preview!r}")
--- Raw data sample ---
2026-06-03 00:27:15,574 INFO dataset.py:3818 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2026-06-03 00:27:15,582 INFO logging.py:416 -- Registered dataset logger for dataset dataset_4_0
2026-06-03 00:27:15,598 WARNING resource_manager.py:169 -- ⚠️  Ray's object store is configured to use only 5.3% of available memory (9.3GiB out of 176.6GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
2026-06-03 00:27:15,598 WARNING __init__.py:28 -- Progress bars disabled. To enable, set `ray.data.DataContext.get_current().enable_progress_bars = True`.
2026-06-03 00:27:16,285 INFO streaming_executor.py:294 -- ✔️  Dataset dataset_4_0 execution finished in 0.69 seconds
  Row 0: ''
  Row 1: ' = Valkyria Chronicles III = \n'

You will see output like the following:

Row 0: ''
Row 1: ' = Valkyria Chronicles III = '

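Each row in Wikitext-103 is a single line of text from a Wikipedia article. Consecutive rows belong to the same article, and blank rows separate paragraphs. New articles begin with a title line such as = Article Title =. The tokenization step below inserts an <|endoftext|> separator token before each title line so that the model learns to reset its context at article boundaries.

Tokenizing and chunking the data#

Language models consume fixed-length sequences of token IDs. The preprocessing step converts raw text into sequences of IDs for next-token prediction.

This tutorial uses tiktoken with the GPT-2 encoding (vocabulary size 50,257). tiktoken is a fast, standalone tokenizer that does not depend on the Hugging Face transformers library.

The tokenize_and_chunk function does the following:

  • Tokenizes each batch of text and concatenates the results into a single stream. Article title lines (for example, = Article Title =) trigger an <|endoftext|> separator so that the model resets its context at article boundaries.

  • Splits the stream into fixed-length chunks of block_size tokens.

  • Returns input_ids for each chunk. During training, the same tensor serves as both input and labels, because GPT2LMHeadModel shifts the labels internally when computing the cross-entropy loss.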
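
As a minimal sketch of the label shift mentioned in the last bullet, the plain-Python snippet below pairs each position with the token that follows it. The toy token IDs and the helper name shift_for_next_token are illustrative assumptions, not part of the tutorial's code. GPT2LMHeadModel performs an equivalent shift internally on its logits and labels.

```python
def shift_for_next_token(input_ids: list[int]) -> tuple[list[int], list[int]]:
    """Pair every position with the token that follows it.

    Hypothetical helper for illustration only: position i is used to
    predict token i + 1, so the last token has no target and the first
    token is never a target.
    """
    context_positions = input_ids[:-1]
    next_token_targets = input_ids[1:]
    return context_positions, next_token_targets


chunk = [50256, 796, 569, 18354, 7496]  # toy IDs, not real model input
contexts, targets = shift_for_next_token(chunk)
for context_token, target_token in zip(contexts, targets):
    print(f"after token {context_token} predict {target_token}")
```

Because the shift happens inside the model, the preprocessing step only needs to emit one array per chunk.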

BLOCK_SIZE = 256
VOCAB_SIZE = 50257

encoding = tiktoken.get_encoding("gpt2")
EOT_TOKEN = encoding.eot_token  # <|endoftext|> token ID (50256)


def _is_article_title(text: str) -> bool:
    """Detect Wikitext article title lines like ' = Some Title = '."""
    stripped = text.strip()
    return stripped.startswith("= ") and stripped.endswith(" =") and not stripped.startswith("= =")


def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Tokenize text and split into fixed-length chunks for language modeling."""
    # Reconstruct the original text stream by joining rows with newlines.
    # Article title lines signal new articles, so we insert an
    # <|endoftext|> separator before them.
    all_tokens: list[int] = []
    for text in batch["text"]:
        if _is_article_title(text):
            all_tokens.append(EOT_TOKEN)
        all_tokens.extend(encoding.encode_ordinary(text + "\n"))

    # Split into fixed-length chunks of block_size tokens.
    num_chunks = len(all_tokens) // BLOCK_SIZE
    all_tokens = all_tokens[: num_chunks * BLOCK_SIZE]

    if num_chunks == 0:
        return {"input_ids": []}

    tokens_array = np.array(all_tokens, dtype=np.int64).reshape(num_chunks, BLOCK_SIZE)
    return {"input_ids": tokens_array}

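Apply the tokenization with map_batches(). This operation is lazy, meaning that Ray Data defers execution until a downstream consumer requests results. Lazy execution lets Ray optimize the entire pipeline before any work begins.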
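
Lazy execution can be pictured with an ordinary Python generator. The sketch below is only an analogy under that assumption: lazy_tokenize and the executed list are hypothetical names, and Ray Data's planner does considerably more than a generator does.

```python
from typing import Iterable, Iterator

executed: list[str] = []  # records which rows were actually processed


def lazy_tokenize(rows: Iterable[str]) -> Iterator[list[str]]:
    """Generator stage: the body runs only when a consumer pulls a result."""
    for row in rows:
        executed.append(row)
        yield row.split()


rows = ["= Title =", "first line of text", "second line of text"]

pipeline = lazy_tokenize(rows)  # builds the stage, processes nothing yet
rows_processed_before_pull = len(executed)

first_result = next(pipeline)  # the consumer requests one result
rows_processed_after_pull = len(executed)

print(rows_processed_before_pull, rows_processed_after_pull, first_result)
```

Defining the stage costs nothing; work happens only when something downstream asks for output.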

# These do not trigger execution.
train_ds = train_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
val_ds = val_ds.map_batches(tokenize_and_chunk, batch_format="numpy")

Inspect the tokenized output with take(2):

print("--- After tokenization ---")
tokenized_sample = train_ds.take(2)
for i, row in enumerate(tokenized_sample):
    ids = row["input_ids"]
    print(f"  Row {i}: input_ids shape={ids.shape}, first 10 tokens={ids[:10].tolist()}")
    print(f"          Decoded: {encoding.decode(ids[:30].tolist())!r}...")
--- After tokenization ---
2026-06-03 00:27:17,195 INFO logging.py:416 -- Registered dataset logger for dataset dataset_7_0
2026-06-03 00:27:17,719 INFO streaming_executor.py:294 -- ✔️  Dataset dataset_7_0 execution finished in 0.52 seconds
  Row 0: input_ids shape=(256,), first 10 tokens=[198, 50256, 796, 569, 18354, 7496, 17740, 6711, 796, 220]
          Decoded: '\n<|endoftext|> = Valkyria Chronicles III = \n\n\n Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦'...
  Row 1: input_ids shape=(256,), first 10 tokens=[33687, 5303, 18024, 6909, 764, 317, 1588, 1074, 286, 8786]
          Decoded: " Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \n\n It"...

Each row now contains a fixed-length input_ids array of 256 tokens.
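
The chunking arithmetic is easier to see on a toy stream. The sketch below uses plain lists and a block size of 4 instead of NumPy and 256, and chunk_token_stream is a hypothetical stand-in for the reshaping step in tokenize_and_chunk. As in that function, tokens left over at the end of a batch are dropped.

```python
def chunk_token_stream(tokens: list[int], block_size: int) -> list[list[int]]:
    """Split a token stream into full blocks and drop the trailing remainder."""
    num_chunks = len(tokens) // block_size
    return [
        tokens[i * block_size : (i + 1) * block_size]
        for i in range(num_chunks)
    ]


toy_stream = list(range(10))  # stand-in for a concatenated token stream
chunks = chunk_token_stream(toy_stream, block_size=4)
dropped = len(toy_stream) - sum(len(chunk) for chunk in chunks)
print(chunks, dropped)
```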

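Streaming execution#

Internally, Ray divides the data into blocks and dispatches them to workers. This block-based architecture enables streaming execution: as soon as a stage outputs a block, the next stage can begin processing it, without waiting for the previous stage to finish the entire dataset. This means the map_batches tokenization above runs concurrently with the training loop as a pipeline, so the full dataset never has to be loaded into memory at once.

When training starts, Ray Data logs the execution plan. For this tutorial, one possible plan is: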

Execution plan: InputDataBuffer[Input]
    -> TaskPoolMapOperator[MapBatches(tokenize_and_chunk)]
    -> OutputSplitter[split(4, equal=True)]

This tells you exactly how Ray Data will stream the data through tokenization and split it across the 4 training workers.
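
The split(4, equal=True) step can be illustrated with a toy splitter. This is a sketch under the assumption that equal=True means every worker receives the same number of rows, with any extras dropped. split_equal is a hypothetical helper, and Ray Data's real splitter operates on streamed blocks rather than an in-memory list.

```python
def split_equal(rows: list[int], num_workers: int) -> list[list[int]]:
    """Deal rows round-robin, then trim so every worker gets the same count."""
    rows_per_worker = len(rows) // num_workers
    shards = [rows[worker::num_workers] for worker in range(num_workers)]
    return [shard[:rows_per_worker] for shard in shards]


toy_rows = list(range(10))  # stand-in for tokenized chunks
shards = split_equal(toy_rows, num_workers=4)
for worker_rank, shard in enumerate(shards):
    print(f"worker {worker_rank}: {shard}")
```

Equal shard sizes matter for data-parallel training because every worker must take the same number of steps for gradient synchronization to line up.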

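Defining the transformer model#

The model is a decoder-only transformer language model built with Hugging Face's GPT2LMHeadModel. The hyperparameters below correspond to the standard GPT-2 "small" architecture, except that n_positions is set to BLOCK_SIZE (256) rather than GPT-2's usual 1,024.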

def create_model():
    """Create a GPT-2 small model with random weights."""
    model = GPT2LMHeadModel(GPT2Config(
        vocab_size=VOCAB_SIZE,
        n_positions=BLOCK_SIZE,
        n_embd=768,
        n_layer=12,
        n_head=12,
    ))
    model.loss_type = "ForCausalLM"
    return model

Verify the model size:

model = create_model()
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params / 1e6:.1f}M")

del model  # Free memory before training
Model parameters: 123.8M

You should see approximately 123.8M parameters.
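
The 123.8M figure can be cross-checked by hand from the hyperparameters. The sketch below assumes the standard GPT-2 layout: a feed-forward width of 4 × n_embd, biases on every linear layer, and an output head that shares its weights with the token embedding so that it is counted once. gpt2_parameter_count is a hypothetical helper, not a Transformers API.

```python
def gpt2_parameter_count(vocab_size: int, n_positions: int, n_embd: int, n_layer: int) -> int:
    """Count trainable parameters, assuming the output head shares the token embedding."""
    token_embedding = vocab_size * n_embd
    position_embedding = n_positions * n_embd

    layer_norm = 2 * n_embd  # scale and bias
    attention_qkv = n_embd * 3 * n_embd + 3 * n_embd
    attention_out = n_embd * n_embd + n_embd
    mlp_in = n_embd * 4 * n_embd + 4 * n_embd
    mlp_out = 4 * n_embd * n_embd + n_embd
    per_layer = 2 * layer_norm + attention_qkv + attention_out + mlp_in + mlp_out

    final_layer_norm = 2 * n_embd
    return token_embedding + position_embedding + n_layer * per_layer + final_layer_norm


total = gpt2_parameter_count(vocab_size=50257, n_positions=256, n_embd=768, n_layer=12)
print(f"{total:,} parameters ({total / 1e6:.1f}M)")
```

The number of attention heads does not appear in the count, because splitting the embedding into heads does not change the size of the projection matrices.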

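Defining the distributed training function#

The training function runs on each worker process. Ray Train manages the distributed setup: it wraps the model in DistributedDataParallel, shards the data across workers, and synchronizes gradients automatically.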
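
To see why sharding the data and averaging gradients is equivalent to one large batch, consider a one-parameter model in plain Python. This is a toy illustration with made-up data and equal-sized shards. The explicit average stands in for the all-reduce that DistributedDataParallel performs, and none of the names below come from Ray or PyTorch.

```python
import math


def mean_gradient(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]

# Shard the batch across 4 hypothetical workers, 2 examples each.
num_workers = 4
shard_size = len(xs) // num_workers
local_gradients = [
    mean_gradient(
        w,
        xs[rank * shard_size : (rank + 1) * shard_size],
        ys[rank * shard_size : (rank + 1) * shard_size],
    )
    for rank in range(num_workers)
]

# Averaging the per-worker gradients stands in for DDP's all-reduce.
synchronized_gradient = sum(local_gradients) / num_workers
full_batch_gradient = mean_gradient(w, xs, ys)
print(synchronized_gradient, full_batch_gradient)
```

The two values agree up to floating-point rounding, which is why every worker can apply the same optimizer step after synchronization.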

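The key Ray Train integration points are:

  • ray.train.get_dataset_shard("train") retrieves the worker's shard of the dataset. Ray Data automatically splits the dataset across all workers.

  • ray.train.torch.prepare_model(model) wraps the model in DistributedDataParallel and moves it to the correct GPU.

  • shard.iter_torch_batches(batch_size=...) returns an iterator of dict[str, torch.Tensor] batches, with tensors automatically placed on the worker's GPU. Setting prefetch_batches=2 prefetches 2 batches.

  • ray.train.report(metrics, checkpoint=...) reports metrics to the driver and saves a checkpoint.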

def train_func_per_worker(config: dict):
    """Training function executed by each distributed worker."""
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    max_grad_norm = config["max_grad_norm"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]
    max_steps_per_epoch = config.get("max_steps_per_epoch")

    # --- Data -----------------------------------------------------------
    # Each worker gets an automatic shard of the dataset.
    train_data_shard = ray.train.get_dataset_shard("train")
    val_data_shard = ray.train.get_dataset_shard("validation")

    # --- Model ----------------------------------------------------------
    model = create_model()
    # prepare_model wraps the model in DistributedDataParallel and places
    # it on the correct device.
    model = ray.train.torch.prepare_model(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # --- Training loop --------------------------------------------------
    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0.0
        train_batches = 0
        train_tokens = 0
        epoch_start = time.perf_counter()

        # iter_torch_batches returns dicts of tensors already on the GPU.
        for batch in train_data_shard.iter_torch_batches(
            batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
        ):
            input_ids = batch["input_ids"]

            # GPT2LMHeadModel shifts labels internally to align each
            # position with the next token, so we can use input_ids as
            # both the input and the labels.
            out = model(input_ids=input_ids, labels=input_ids)
            loss = out.loss

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for training stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            train_loss_sum += loss.item()
            train_batches += 1
            train_tokens += input_ids.numel()

            if max_steps_per_epoch and train_batches >= max_steps_per_epoch:
                break

        train_elapsed = time.perf_counter() - epoch_start
        avg_train_loss = train_loss_sum / max(train_batches, 1)

        # --- Validation -----------------------------------------------------
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0

        with torch.no_grad():
            for batch in val_data_shard.iter_torch_batches(
                batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
            ):
                input_ids = batch["input_ids"]

                out = model(input_ids=input_ids, labels=input_ids)
                loss = out.loss
                val_loss_sum += loss.item()
                val_batches += 1

                if max_steps_per_epoch and val_batches >= max_steps_per_epoch:
                    break

        avg_val_loss = val_loss_sum / max(val_batches, 1)
        epoch_elapsed = time.perf_counter() - epoch_start

        # --- Report metrics and save checkpoint ------------------------------
        metrics = {
            "train_loss": round(avg_train_loss, 4),
            "val_loss": round(avg_val_loss, 4),
            "epoch": epoch,
            "epoch_time_sec": round(epoch_elapsed, 2),
            "epoch_tokens": train_tokens,
            "tokens_per_sec": round(train_tokens / max(train_elapsed, 1e-6), 2),
        }

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                {
                    "epoch": epoch,
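                    # Unwrap DDP when present; prepare_model may skip the
                    # wrapper (for example, with a single worker).
                    "model_state_dict": getattr(model, "module", model).state_dict(),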
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                os.path.join(temp_checkpoint_dir, "checkpoint.pt"),
            )
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics=metrics, checkpoint=checkpoint)

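Configuring and launching distributed training#

TorchTrainer brings everything together. Calling trainer.fit() finally triggers execution of the full data pipeline and the training loop. The trainer takes the following arguments:

  • train_loop_per_worker: the function each worker executes (here, train_func_per_worker).

  • train_loop_config: a dict of hyperparameters passed to the training function.

  • datasets: a dict of Ray Datasets. Ray Train automatically splits each dataset across workers.

  • scaling_config: specifies the number of workers and whether to use GPUs.

Setting num_workers=4 launches 4 parallel workers, one per GPU. Ray Train handles torch.distributed initialization, NCCL backend setup, and DistributedDataParallel wrapping behind the scenes. In the logs, you see each worker assigned a rank and a device: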

Started training worker group of size 4:

* (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
* (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
...
Moving model to device: cuda:0
Wrapping provided model in DistributedDataParallel.

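batch_size_per_worker is the number of sequences each worker processes per gradient step. With 4 workers and a per-worker batch size of 16, the effective global batch size is 4 × 16 = 64 sequences, or 64 × 256 = 16,384 tokens per optimizer step (4,096 tokens per worker).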
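
The batch-size arithmetic can be checked in a few lines. The variable names are illustrative, and the values come from the configuration in this tutorial: each worker handles 4,096 tokens per step, so 4 workers together handle 16,384. The last figure assumes the smoke-test limit of 5 steps per epoch and matches the per-worker epoch_tokens value in the training logs.

```python
num_workers = 4
batch_size_per_worker = 16
block_size = 256
smoke_test_steps_per_epoch = 5  # max_steps_per_epoch when SMOKE_TEST is True

global_batch_sequences = num_workers * batch_size_per_worker
tokens_per_step_per_worker = batch_size_per_worker * block_size
tokens_per_step_global = global_batch_sequences * block_size
tokens_per_epoch_per_worker = smoke_test_steps_per_epoch * tokens_per_step_per_worker

print(global_batch_sequences)       # 64 sequences per optimizer step
print(tokens_per_step_per_worker)   # 4096 tokens per worker per step
print(tokens_per_step_global)       # 16384 tokens per optimizer step
print(tokens_per_epoch_per_worker)  # 20480 tokens per worker per epoch
```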

USE_GPU = torch.cuda.is_available()
NUM_WORKERS = max(torch.cuda.device_count(), 1)  # One worker per available GPU
NUM_EPOCHS = 5
BATCH_SIZE_PER_WORKER = 16
LR = 3e-4
WEIGHT_DECAY = 0.1
MAX_GRAD_NORM = 1.0

trainer = TorchTrainer(
    train_loop_per_worker=train_func_per_worker,
    train_loop_config={
        "lr": LR,
        "weight_decay": WEIGHT_DECAY,
        "max_grad_norm": MAX_GRAD_NORM,
        "epochs": NUM_EPOCHS,
        "batch_size_per_worker": BATCH_SIZE_PER_WORKER,
        "max_steps_per_epoch": 5 if SMOKE_TEST else None,
    },
    # Register the datasets,
    datasets={"train": train_ds, "validation": val_ds},
    scaling_config=ScalingConfig(
        num_workers=NUM_WORKERS,
        use_gpu=USE_GPU,
    ),
    run_config=RunConfig(
        name="gpt2-small-pretraining",
        storage_path="/tmp/ray-train-checkpoints",
    ),
)

result = trainer.fit()
(TrainController pid=7979) Requesting resources: {'GPU': 1} * 4
(TrainController pid=7979) Attempting to start training worker group of size 4 with the following resources: [{'GPU': 1}] * 4
(RayTrainWorker pid=8100) Setting up process group for: env:// [rank=0, world_size=4]
(TrainController pid=7979) Started training worker group of size 4:
(TrainController pid=7979) - (ip=172.17.0.2, pid=8100) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=7979) - (ip=172.17.0.2, pid=8099) world_rank=1, local_rank=1, node_rank=0
(TrainController pid=7979) - (ip=172.17.0.2, pid=8102) world_rank=2, local_rank=2, node_rank=0
(TrainController pid=7979) - (ip=172.17.0.2, pid=8101) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=8100) Moving model to device: cuda:0
(RayTrainWorker pid=8100) Wrapping provided model in DistributedDataParallel.
(SplitCoordinator pid=8563) Registered dataset logger for dataset train_8_0
(SplitCoordinator pid=8563) ⚠️  Ray's object store is configured to use only 5.3% of available memory (9.3GiB out of 176.6GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
(SplitCoordinator pid=8563) Progress bars disabled. To enable, set `ray.data.DataContext.get_current().enable_progress_bars = True`.
(SplitCoordinator pid=8563) ✔️  Dataset train_8_0 execution finished in 1.64 seconds
(RayTrainWorker pid=8100) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-40.004632)
(RayTrainWorker pid=8100) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-40.004632), metrics={'train_loss': 9.8121, 'val_loss': 8.9404, 'epoch': 0, 'epoch_time_sec': 5.16, 'epoch_tokens': 20480, 'tokens_per_sec': 4607.68}, validation=False)
(SplitCoordinator pid=8564) Registered dataset logger for dataset validation_10_0
(SplitCoordinator pid=8564) ⚠️  Ray's object store is configured to use only 5.3% of available memory (9.3GiB out of 176.6GiB total). For optimal Ray Data performance, we recommend setting the object store to at least 50% of available memory. You can do this by setting the 'object_store_memory' parameter when calling ray.init() or by setting the RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION environment variable.
(SplitCoordinator pid=8564) Progress bars disabled. To enable, set `ray.data.DataContext.get_current().enable_progress_bars = True`.
(SplitCoordinator pid=8564) ✔️  Dataset validation_10_0 execution finished in 0.18 seconds
(SplitCoordinator pid=8563) Registered dataset logger for dataset train_8_1
(SplitCoordinator pid=8563) ✔️  Dataset train_8_1 execution finished in 0.19 seconds
(RayTrainWorker pid=8099) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-46.998696) [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.rayai.org.cn/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=8099) Reporting training result 2: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-46.998696), metrics={'train_loss': 8.4915, 'val_loss': 8.2414, 'epoch': 1, 'epoch_time_sec': 3.16, 'epoch_tokens': 20480, 'tokens_per_sec': 8125.6}, validation=False) [repeated 4x across cluster]
(SplitCoordinator pid=8564) Registered dataset logger for dataset validation_10_1
(SplitCoordinator pid=8564) ✔️  Dataset validation_10_1 execution finished in 0.16 seconds
(SplitCoordinator pid=8563) Registered dataset logger for dataset train_8_2
(SplitCoordinator pid=8563) ✔️  Dataset train_8_2 execution finished in 0.19 seconds
(RayTrainWorker pid=8100) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-54.170024) [repeated 4x across cluster]
(RayTrainWorker pid=8100) Reporting training result 3: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-27-54.170024), metrics={'train_loss': 7.6134, 'val_loss': 7.756, 'epoch': 2, 'epoch_time_sec': 3.15, 'epoch_tokens': 20480, 'tokens_per_sec': 8152.04}, validation=False) [repeated 4x across cluster]
(SplitCoordinator pid=8564) Registered dataset logger for dataset validation_10_2
(SplitCoordinator pid=8564) ✔️  Dataset validation_10_2 execution finished in 0.16 seconds
(SplitCoordinator pid=8563) Registered dataset logger for dataset train_8_3
(SplitCoordinator pid=8563) ✔️  Dataset train_8_3 execution finished in 0.19 seconds
(RayTrainWorker pid=8100) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-28-01.252648) [repeated 4x across cluster]
(RayTrainWorker pid=8100) Reporting training result 4: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-28-01.252648), metrics={'train_loss': 7.2063, 'val_loss': 7.7138, 'epoch': 3, 'epoch_time_sec': 3.18, 'epoch_tokens': 20480, 'tokens_per_sec': 8044.78}, validation=False) [repeated 4x across cluster]
(SplitCoordinator pid=8564) Registered dataset logger for dataset validation_10_3
(SplitCoordinator pid=8564) ✔️  Dataset validation_10_3 execution finished in 0.16 seconds
(SplitCoordinator pid=8563) Registered dataset logger for dataset train_8_4
(SplitCoordinator pid=8563) ✔️  Dataset train_8_4 execution finished in 0.25 seconds
(RayTrainWorker pid=8100) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-28-08.552107) [repeated 4x across cluster]
(RayTrainWorker pid=8100) Reporting training result 5: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-06-03_00-28-08.552107), metrics={'train_loss': 7.2632, 'val_loss': 7.808, 'epoch': 4, 'epoch_time_sec': 3.25, 'epoch_tokens': 20480, 'tokens_per_sec': 7818.89}, validation=False) [repeated 4x across cluster]
(SplitCoordinator pid=8564) Registered dataset logger for dataset validation_10_4
(SplitCoordinator pid=8564) ✔️  Dataset validation_10_4 execution finished in 0.16 seconds

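Inspecting the results#

After training, the Result object contains the final metrics and checkpoint. result.metrics comes from the last ray.train.report() call. result.checkpoint contains the checkpoint saved by the last ray.train.report() call.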

print("\nTraining finished!")
Training finished!

result.metrics contains the metrics dict reported by the last ray.train.report() call, for example:

{'train_loss': 7.0646, 'val_loss': 7.6051, 'epoch': 4,
 'epoch_time_sec': 12.34, 'epoch_tokens': 20480, 'tokens_per_sec': 1759.8}

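The per-worker logs show the training loss, validation loss, and throughput metrics for each epoch. Because the weights start out random and the run takes only a few training steps, expect the loss to be high: it starts near ln(50,257) ≈ 10.8 and falls only to roughly 7 to 8 in the smoke-test run logged above.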
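
To put these loss values in context, the sketch below computes the cross-entropy of a uniform guess over the GPT-2 vocabulary and converts a loss into perplexity. It assumes the reported loss is a mean cross-entropy in nats, which is what PyTorch's cross-entropy loss returns. The perplexity helper is a hypothetical name.

```python
import math

vocab_size = 50257

# A model that spreads probability uniformly over the vocabulary has
# cross-entropy ln(vocab_size); an untrained model starts close to this.
uniform_loss = math.log(vocab_size)


def perplexity(cross_entropy_loss: float) -> float:
    """Convert a mean cross-entropy loss in nats into perplexity."""
    return math.exp(cross_entropy_loss)


print(f"uniform-guess loss: {uniform_loss:.2f}")
print(f"perplexity at val_loss=7.6051: {perplexity(7.6051):.0f}")
```

A loss well below the uniform baseline after only a handful of steps mostly reflects the model learning token frequencies, not language structure.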

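Checkpointing#

In production training, you can enable checkpointing to make your training jobs resilient to unexpected failures. Checkpointing lets you take advantage of the Ray Train fault tolerance mechanisms described in the Fault tolerance section.

Ray Train offers several checkpointing optimizations. Asynchronous uploading lets you continue training while checkpoints stream to remote storage in the background. Distributed checkpointing uploads each worker's shard in parallel, which avoids gathering everything into a single worker's memory and reduces the risk of out-of-memory (OOM) errors for large models.

For a complete guide to checkpointing with Ray Train, see the Ray Train checkpointing guide.

Scaling to a multi-node cluster#

The code above runs on a single 4-GPU machine. Scaling to a multi-node cluster requires only two changes:

  1. Increase num_workers to match the total number of GPUs in the cluster.

  2. Set a shared storage path so that all nodes can access checkpoints.

For example, to train on a cluster of 4 nodes with 4 GPUs each (16 GPUs total):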

trainer = TorchTrainer(
    train_loop_per_worker=train_func_per_worker,
    train_loop_config={...},
    datasets={"train": train_ds, "validation": val_ds},
    scaling_config=ScalingConfig(
        num_workers=16,  # 4 nodes x 4 GPUs
        use_gpu=True,
    ),
    run_config=RunConfig(
        # Shared storage accessible from all nodes
        storage_path="s3://my-bucket/ray-checkpoints",
        checkpoint_config=CheckpointConfig(num_to_keep=2),
    ),
)

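Ray Train automatically:

  • Launches workers across all available nodes, bringing up new nodes through Ray cluster autoscaling when needed.

  • Shards data across all workers.

No changes to the training function are needed. The same train_func_per_worker runs identically on 1 GPU or 256 GPUs.

This tutorial uses DistributedDataParallel (DDP), which replicates the full model on every GPU. For larger models that don't fit on a single GPU, you can switch to FullyShardedDataParallel (FSDP) by setting prepare_model(parallel_strategy="fsdp"), which shards parameters, gradients, and optimizer states across workers.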
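
A back-of-the-envelope estimate shows why sharding helps. The sketch below assumes fp32 training with AdamW, which keeps weights, gradients, and two optimizer moment buffers per parameter, and it ignores activations, communication buffers, and framework overhead. Treat the numbers as rough lower bounds, not measurements; per_gpu_gib is a hypothetical helper.

```python
num_params = 123_800_000  # roughly the 123.8M reported earlier in this tutorial
bytes_per_fp32 = 4

# Assumed per-parameter state for fp32 training with AdamW:
# weights + gradients + two optimizer moment buffers.
state_copies = 1 + 1 + 2
model_state_bytes = num_params * bytes_per_fp32 * state_copies


def per_gpu_gib(num_workers: int, sharded: bool) -> float:
    """Model-state memory per GPU in GiB, ignoring activations and buffers."""
    bytes_per_gpu = model_state_bytes / num_workers if sharded else model_state_bytes
    return bytes_per_gpu / 2**30


for workers in (1, 4, 16):
    print(
        f"{workers:>2} workers: replicated={per_gpu_gib(workers, False):.2f} GiB, "
        f"fully sharded={per_gpu_gib(workers, True):.2f} GiB"
    )
```

For a model this small, replication fits comfortably on a 16 GB GPU. The per-GPU saving from sharding grows with both model size and worker count.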

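Heterogeneous clusters: separating data and training resources#

Because Ray Data and Ray Train are separate systems, they don't have to share the same machines. By default, Ray Data preprocessing and training workers all run on the same nodes. However, you can optionally add CPU-only nodes to your cluster, and Ray Data automatically schedules preprocessing tasks on them, keeping your expensive GPU nodes free for training.

This is useful when data preprocessing is a bottleneck. If you notice low GPU utilization because workers are waiting on data, you can add inexpensive CPU-only nodes to the cluster, and Ray Data scales preprocessing out to them.

For more information, see Configuring data ingest.

Fault tolerance#

Long-running distributed training jobs are vulnerable to failures, including hardware failures, network failures, and preemption. Without fault tolerance, any of these events can force you to restart training from scratch, wasting time and compute.

Ray Train can handle these failures automatically. When a worker process crashes, Ray Train restarts it in place and resumes training. If an entire node goes down, Ray Train provisions a replacement and recovers from the most recent checkpoint, so only a small amount of progress is lost. This also makes it possible to interrupt a training job and resume it later.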
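
The recover-from-checkpoint behavior can be pictured with a toy retry loop. This is a simplified model and not Ray Train's implementation: run_with_retries is a hypothetical function, crashes are simulated at chosen epochs, a "checkpoint" is just the count of completed epochs, and the failure budget is counted over the whole run.

```python
def run_with_retries(total_epochs: int, fail_at: set[int], max_failures: int) -> tuple[int, int]:
    """Toy model: each attempt resumes from the last checkpointed epoch."""
    last_checkpoint = 0  # number of completed, checkpointed epochs
    failures = 0
    pending_failures = set(fail_at)  # each simulated crash happens once
    while True:
        try:
            for epoch in range(last_checkpoint, total_epochs):
                if epoch in pending_failures:
                    pending_failures.discard(epoch)
                    raise RuntimeError(f"simulated worker crash in epoch {epoch}")
                last_checkpoint = epoch + 1  # checkpoint after each epoch
            return last_checkpoint, failures
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise  # failure budget exhausted, give up


completed_epochs, failure_count = run_with_retries(
    total_epochs=5, fail_at={1, 3}, max_failures=3
)
print(f"completed {completed_epochs} epochs after {failure_count} failures")
```

Because each retry starts from the last checkpoint, a crash costs at most the work done since that checkpoint was saved.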

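To enable automatic failure recovery, configure FailureConfig in your RunConfig. The max_failures parameter controls how many consecutive failures Ray Train tolerates before giving up: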

from ray.train import FailureConfig

run_config = RunConfig(
    storage_path="s3://my-bucket/ray-checkpoints",
    failure_config=FailureConfig(max_failures=3),
    checkpoint_config=CheckpointConfig(num_to_keep=2),
)

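For more details, see the Ray Train fault tolerance guide.

Monitoring your training job#

Monitoring is critical when running distributed training. The Ray dashboard displays real-time metrics, including:

  • Training loss and validation metrics per epoch

  • GPU utilization and memory usage per worker

  • Data loading throughput

  • Worker status and error logs

To view the dashboard, open the link printed in the logs after Ray initializes. The dashboard is typically served on port 8265; for the local Ray instance started in this tutorial, the startup log shows 127.0.0.1:8265.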

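The dashboard lets you:

  • Monitor training progress across all workers

  • Inspect logs from individual workers

  • Identify data loading or communication bottlenecks

  • View CPU, GPU, and memory resource usage per worker

  • Debug failures with detailed error messages and stack traces

For more information, see the Ray Train monitoring documentation.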

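Conclusion#

In this tutorial, you:

  • Pre-trained a GPT-2 (~124M-parameter) language model using Hugging Face Transformers and PyTorch.

  • Loaded and preprocessed the Wikitext-103 dataset using Ray Data's distributed streaming.

  • Ran distributed training across 4 GPUs using Ray Train's TorchTrainer, with only minimal changes to a standard PyTorch training loop.

  • Learned how to save and load distributed checkpoints for model recovery.

  • Learned how to scale to multi-node clusters by changing ScalingConfig and RunConfig.

  • Learned how heterogeneous clusters let you run data preprocessing on CPU nodes and training on GPU nodes to optimize cost and performance.

  • Learned about Ray Train's fault tolerance mechanisms for production training jobs.

  • Monitored training with the Ray dashboard.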

Further reading#

Total running time of the script: (1 minute 17.526 seconds)