Mosaic: Memory Profiling for PyTorch
Author: Basil Wong
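What you will learn:
How to capture and analyze PyTorch memory snapshots
How to identify the memory savings from activation checkpointing
How to debug unexpected memory usage caused by leftover code
How to integrate memory analysis into a training pipeline
Prerequisites:
PyTorch v2.0.0 or later
A CUDA-capable GPU
Basic familiarity with PyTorch training loops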
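This tutorial demonstrates how to use Mosaic, a post-processing analysis tool for PyTorch memory snapshots. Mosaic helps analyze GPU memory usage in distributed deep learning and provides detailed insights into memory allocations, peak usage, and memory imbalance across parallel workers.
Mosaic was instrumental in debugging out-of-memory (OOM) issues during 405B LLaMA training and is now open source.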
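Introduction to Mosaic
Overview
In distributed deep learning, understanding how GPU memory is used is essential for optimizing training efficiency and debugging out-of-memory (OOM) errors. Mosaic is a post-hoc memory analysis tool designed for large-scale jobs. It analyzes memory snapshots captured while a PyTorch training job runs and reports detailed information about memory allocations, peak usage, and memory imbalance across parallel workers.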
Getting Started
Clone the mosaic repository and install it from the mosaic directory:
git clone https://github.com/facebookresearch/mosaic
cd mosaic
python3 -m venv venv
source venv/bin/activate
pip3 install -r requirements.txt
pip3 install -e .
Alternatively, install it directly with pip:
pip install git+https://github.com/facebookresearch/mosaic.git
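Simple Usage Examples
1. Peak Memory Usage Analysis
When troubleshooting memory problems such as out-of-memory (OOM) errors, focus first on peak memory usage. The mosaic_get_memory_usage_peak command shows the stack traces of the memory allocations that contribute to the peak.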
mosaic_get_memory_usage_peak --snapshot <path to snapshot>
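2. Categorical Memory Profiling
Mosaic groups allocations into categories (activation, backward, optimizer, and so on):
Activation memory: tensors saved for the backward pass
Gradient memory: gradients computed during the backward pass
Optimizer state: optimizer buffers such as Adam's momentum and variance (first and second moments) or SGD's momentum buffer
Parameter memory: model weights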
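To get a feel for the size of the optimizer-state category, here is a back-of-the-envelope estimate for AdamW on GPT-2 small. It is a sketch under stated assumptions: fp32 parameters, two fp32 state buffers per parameter (AdamW's exp_avg and exp_avg_sq, ignoring the scalar step counters), and the GPT-2 small configuration (vocabulary 50,257, 1,024 positions, hidden size 768, 12 layers) with the LM head tied to the token embedding. The function names are illustrative and not part of Mosaic.

```python
def gpt2_param_count(vocab=50257, n_positions=1024, d=768, n_layer=12):
    """Count GPT-2 parameters, assuming the LM head is tied to the token embedding."""
    embeddings = vocab * d + n_positions * d        # wte + wpe
    layer_norm = 2 * d                              # weight + bias
    per_layer = (
        layer_norm                                  # ln_1
        + d * 3 * d + 3 * d                         # attn.c_attn (fused Q, K, V)
        + d * d + d                                 # attn.c_proj
        + layer_norm                                # ln_2
        + d * 4 * d + 4 * d                         # mlp.c_fc
        + 4 * d * d + d                             # mlp.c_proj
    )
    return embeddings + n_layer * per_layer + layer_norm  # + final ln_f


def adamw_state_bytes(num_params, bytes_per_value=4, buffers_per_param=2):
    """Bytes for AdamW's exp_avg and exp_avg_sq buffers (assumed fp32)."""
    return num_params * buffers_per_param * bytes_per_value


params = gpt2_param_count()
state_mib = adamw_state_bytes(params) / 1024**2
print(f"GPT-2 parameters: {params:,}")        # 124,439,808
print(f"AdamW state: {state_mib:.1f} MiB")     # 949.4 MiB
```

This matches the 949.4 MB optimizer-state figure in the Case 1 results, which suggests those figures use 1024-based units.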
mosaic_get_memory_profile --snapshot <path> --out-path <html> \
--profile categories
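Example HTML output:
Categorical memory profile showing the memory distribution by type (activations, gradients, optimizer, and so on)
To keep the categories in allocation order, add --preserve-allocation-order: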
mosaic_get_memory_profile --snapshot <path> --out-path <html> \
--profile categories --preserve-allocation-order
Categorical profile with --preserve-allocation-order, showing memory allocations in chronological order
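The CLI examples above, and the scripted runs later in this tutorial, assemble the same flags by hand. A small helper can build the argument list once. This is a sketch: build_profile_command and run_profile are hypothetical names, not part of Mosaic. The helper uses only flags shown in this tutorial and assumes that --out-path is also accepted together with --profile custom.

```python
import json
import shutil
import subprocess


def build_profile_command(snapshot, out_path, profile="categories",
                          preserve_allocation_order=False, custom_profile=None):
    """Build an argument list for mosaic_get_memory_profile (hypothetical helper).

    custom_profile is a {category_name: regex} dict serialized to JSON for
    --custom-profile, mirroring the '{"ncclx": "ncclx"}' example.
    """
    cmd = ["mosaic_get_memory_profile", "--snapshot", snapshot,
           "--out-path", out_path, "--profile", profile]
    if preserve_allocation_order:
        cmd.append("--preserve-allocation-order")
    if custom_profile is not None:
        cmd += ["--custom-profile", json.dumps(custom_profile)]
    return cmd


def run_profile(cmd):
    """Run the command if its executable is on PATH; return the exit code or None."""
    if shutil.which(cmd[0]) is None:
        return None
    return subprocess.run(cmd, capture_output=True, text=True).returncode
```

For example, `run_profile(build_profile_command("snapshot_baseline.pickle", "profile_baseline.html", preserve_allocation_order=True))` mirrors the second command above.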
3. Custom Dictionary Profiling
For targeted analysis using regex matching:
mosaic_get_memory_profile --snapshot <path> --profile custom \
--custom-profile '{"ncclx": "ncclx"}'
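This is useful for tracking specific kernels, optimizers, or custom code patterns.
Custom profile that uses regex patterns to track specific operations such as NCCL communication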
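Conceptually, a custom profile maps each category name to a regular expression, and an allocation is attributed to a category when its stack trace matches. The sketch below shows only that idea. It is not Mosaic's implementation. The categorize function, the first-match-wins rule, the "other" fallback, and the frame strings are all hypothetical.

```python
import re
from collections import defaultdict


def categorize(allocations, custom_profile, default="other"):
    """Sum allocation sizes per category using {category: regex} matching.

    allocations: iterable of (size_bytes, [frame strings]) pairs.
    A category matches if its regex is found in any frame; the first match wins.
    """
    compiled = {name: re.compile(pattern) for name, pattern in custom_profile.items()}
    totals = defaultdict(int)
    for size, frames in allocations:
        category = default
        for name, regex in compiled.items():
            if any(regex.search(frame) for frame in frames):
                category = name
                break
        totals[category] += size
    return dict(totals)


# Hypothetical allocations: (bytes, stack frames)
allocations = [
    (4096, ["ncclx::allReduce", "ProcessGroupNCCL::allreduce"]),
    (1024, ["torch/optim/adam.py:zeros_like"]),
    (2048, ["ncclx::allGather"]),
]
print(categorize(allocations, {"ncclx": "ncclx", "optimizer": r"torch/optim/"}))
# {'ncclx': 6144, 'optimizer': 1024}
```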
Dependencies and Imports
Set up the dependencies and imports required for this tutorial.
import subprocess
import sys
import shutil
from contextlib import contextmanager
import pickle
# Fix for sphinx-gallery environment where __main__.__file__ may not exist
# This is needed for transformers library compatibility
import os
if not hasattr(sys.modules["__main__"], "__file__"):
    # Use this file's path as a fallback, or a dummy path if __file__ is not available
    try:
        sys.modules["__main__"].__file__ = os.path.abspath(__file__)
    except NameError:
        # __file__ not available, use transformers modeling file as fallback
        import transformers.modeling_utils
        sys.modules["__main__"].__file__ = transformers.modeling_utils.__file__
import torch
from torch.utils.data import DataLoader, Dataset
# Install dependencies if needed
try:
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
except ImportError:
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "transformers"]
    )
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
try:
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract
except ImportError:
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "git+https://github.com/facebookresearch/mosaic.git",
        ]
    )
    from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
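The training functions in this tutorial call two helpers, RandomTokenDataset and capture_memory_snapshot, that are not defined in this section. If you run the section on its own, you can use the minimal stand-ins below. They rest on several assumptions:
- The dataset yields random token IDs, with labels equal to the inputs (GPT2LMHeadModel shifts labels internally).
- The defaults num_samples=100 and seed=0 are guesses, not the tutorial's originals.
- The context manager uses PyTorch's private memory-history recorder, torch.cuda.memory._record_memory_history and torch.cuda.memory._dump_snapshot. The keyword arguments follow PyTorch 2.1 and later.
- The max_entries value is arbitrary.

```python
from contextlib import contextmanager

import torch
from torch.utils.data import Dataset


class RandomTokenDataset(Dataset):
    """Random token sequences for causal LM training (hypothetical stand-in)."""

    def __init__(self, vocab_size, seq_length, num_samples=100, seed=0):
        # Assumed defaults; pre-generate everything so a given seed is reproducible.
        generator = torch.Generator().manual_seed(seed)
        self.tokens = torch.randint(
            0, vocab_size, (num_samples, seq_length), generator=generator
        )

    def __len__(self):
        return self.tokens.shape[0]

    def __getitem__(self, idx):
        ids = self.tokens[idx]
        # GPT2LMHeadModel shifts labels internally, so labels can equal input_ids.
        return {"input_ids": ids, "labels": ids.clone()}


@contextmanager
def capture_memory_snapshot(snapshot_path, max_entries=100_000):
    """Record CUDA allocator history inside the block, then dump it to a pickle.

    Assumes PyTorch >= 2.1 private APIs; does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        yield
        return
    torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        yield
    finally:
        torch.cuda.memory._dump_snapshot(snapshot_path)
        # enabled=None stops recording.
        torch.cuda.memory._record_memory_history(enabled=None)
```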
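Case 1: Understanding the Memory Difference from Activation Checkpointing
This section demonstrates how to use Mosaic to analyze and compare GPU memory usage across different model configurations.
What we'll do:
Train GPT-2 and capture a memory snapshot (baseline)
Enable activation checkpointing and train again (modified)
Use Mosaic to identify exactly where the memory savings occur
Training Function for the Activation Checkpointing Comparison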
def run_training_ac(
    activation_checkpointing: bool,
    snapshot_path: str,
    batch_size: int = 4,
    seq_length: int = 512,
    num_steps: int = 5,
):
    """Run training loop and capture memory snapshot.
    Args:
        activation_checkpointing: Whether to enable gradient checkpointing.
        snapshot_path: Path to save the memory snapshot.
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps to run.
    Returns:
        Peak GPU memory usage in GB.
    """
    # Clear any previous memory
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    # Load model
    print(f"Loading GPT-2 (activation_checkpointing={activation_checkpointing})...")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    if activation_checkpointing:
        model.gradient_checkpointing_enable()
        print("Activation checkpointing is ENABLED")
    else:
        print("Activation checkpointing is DISABLED")
    model = model.to(device)
    model.train()
    # Create dataset and dataloader
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size,
        seq_length=seq_length,
        num_samples=100,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Training loop with memory capture
    print(f"Running {num_steps} training steps...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}/{num_steps}, Loss: {loss.item():.4f}")
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"✓ Peak GPU memory: {peak_memory_gb:.2f} GB")
    # Cleanup
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory_gb
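Run Baseline Training (Without Activation Checkpointing)
Note
This tutorial requires a CUDA-capable GPU. If you are running in Google Colab, make sure to select a GPU runtime: Runtime → Change runtime type → Hardware accelerator → GPU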
if not torch.cuda.is_available():
    print("=" * 60)
    print("WARNING: No CUDA GPU detected!")
    print("=" * 60)
    print("\nThis tutorial requires a CUDA-capable GPU for memory profiling.")
    print("\nIf you're running in Google Colab:")
    print(" 1. Go to Runtime → Change runtime type")
    print(" 2. Set Hardware accelerator to 'GPU'")
    print(" 3. Click 'Save' and re-run the notebook")
    print("\nSkipping GPU memory profiling examples...")
    HAS_CUDA = False
else:
    HAS_CUDA = True
# Check if Mosaic CLI is available
HAS_MOSAIC_CLI = shutil.which("mosaic_get_memory_profile") is not None
if HAS_CUDA and not HAS_MOSAIC_CLI:
    print("Note: Mosaic CLI not found. Install Mosaic to generate HTML profiles.")
    print(" pip install git+https://github.com/facebookresearch/mosaic.git")
if HAS_CUDA:
    print("=" * 60)
    print("BASELINE: Training WITHOUT Activation Checkpointing")
    print("=" * 60)
    baseline_memory = run_training_ac(
        activation_checkpointing=False,
        snapshot_path="snapshot_baseline.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )
Run Modified Training (With Activation Checkpointing)
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("MODIFIED: Training WITH Activation Checkpointing")
    print("=" * 60)
    ac_memory = run_training_ac(
        activation_checkpointing=True,
        snapshot_path="snapshot_with_ac.pickle",
        batch_size=4,
        seq_length=512,
        num_steps=5,
    )
    # Summary
    print("\n" + "=" * 60)
    print("MEMORY COMPARISON SUMMARY")
    print("=" * 60)
    print(f"Baseline (no AC): {baseline_memory:.2f} GB")
    print(f"With AC: {ac_memory:.2f} GB")
    if baseline_memory > 0:
        saved_pct = 100 * (baseline_memory - ac_memory) / baseline_memory
        print(
            f"Memory Saved: {baseline_memory - ac_memory:.2f} GB ({saved_pct:.1f}%)"
        )
Generate Categorical Memory Profiles with Mosaic
Use Mosaic to generate HTML profiles for both snapshots.
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("\n" + "=" * 60)
    print("MOSAIC: Categorical Memory Profiling")
    print("=" * 60)
    # Generate HTML profiles using subprocess
    print("\nGenerating baseline profile...")
    result1 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot",
            "snapshot_baseline.pickle",
            "--out-path",
            "profile_baseline.html",
            "--profile",
            "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate",
            "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result1.stdout)
    if result1.stderr:
        print(result1.stderr)
    print("\nGenerating activation checkpointing profile...")
    result2 = subprocess.run(
        [
            "mosaic_get_memory_profile",
            "--snapshot",
            "snapshot_with_ac.pickle",
            "--out-path",
            "profile_with_ac.html",
            "--profile",
            "categories",
            "--preserve-allocation-order",
            "--plotter_sampling_rate",
            "20",
        ],
        capture_output=True,
        text=True,
    )
    print(result2.stdout)
    if result2.stderr:
        print(result2.stderr)
    if result1.returncode == 0 and result2.returncode == 0:
        print("\nGenerated profile_baseline.html")
        print("Generated profile_with_ac.html")
        print("\nDownload these files to view the interactive memory profiles.")
    else:
        print("\nNote: Mosaic profile generation encountered issues.")
        print("This may happen if running in an environment without full Mosaic support.")
Download the Generated Files (Google Colab)
If you are running in Google Colab, uncomment the following lines to download the generated snapshots and profiles:
# from google.colab import files
#
# print("Downloading memory snapshots and profiles...")
# files.download('snapshot_baseline.pickle')
# files.download('snapshot_with_ac.pickle')
# files.download('profile_baseline.html')
# files.download('profile_with_ac.html')
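Interpreting the Results: Activation Checkpointing
The generated HTML profiles visualize memory usage over time, with allocations colored by category. Here is what the profiles look like:
Baseline (without activation checkpointing): note the large block of activation memory (shown in a single color) that persists throughout the forward pass.
With activation checkpointing: activation memory is much smaller, because intermediate activations are discarded and then recomputed during the backward pass.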
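Our Observations
Based on Mosaic's categorical profiling results:
| Metric | Baseline | With Activation Checkpointing | Difference |
|---|---|---|---|
| Total peak memory | 4.62 GB | 2.55 GB | 2.07 GB (45% reduction) |
| Activation memory | 2.93 GB | 872.79 MB | 2.08 GB saved (71% reduction) |
| Backward/gradient memory | 793.39 MB | 785.27 MB | ~8 MB (negligible change) |
| Optimizer state | 949.4 MB | 949.4 MB | No change |
| Unknown | 32 KB | 32 KB | No change |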
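If you export per-category totals from two profiles, you can script the comparison in this table. The sketch below uses the table's numbers converted to MiB, assuming the table's GB and MB are 1024-based. The dictionary layout and the compare_categories name are hypothetical; this is not a Mosaic output format.

```python
def compare_categories(baseline_mib, candidate_mib):
    """Return (category, before, after, delta, pct_change) rows, largest drop first."""
    rows = []
    for category, before in baseline_mib.items():
        after = candidate_mib.get(category, 0.0)
        delta = after - before
        pct = 100.0 * delta / before if before else 0.0
        rows.append((category, before, after, delta, pct))
    return sorted(rows, key=lambda row: row[3])  # most negative delta first


MIB_PER_GIB = 1024.0  # assumption: the table's GB figures are 1024-based
baseline = {"activation": 2.93 * MIB_PER_GIB, "backward": 793.39, "optimizer": 949.4}
with_ac = {"activation": 872.79, "backward": 785.27, "optimizer": 949.4}

for category, before, after, delta, pct in compare_categories(baseline, with_ac):
    print(f"{category:<11} {before:8.1f} -> {after:8.1f} MiB ({delta:+8.1f} MiB, {pct:+.0f}%)")
```

Sorting by the absolute delta puts the category that explains most of the savings first, which is the same conclusion the categorical profile makes visible.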
Key Insights
Main finding: activation memory dropped from 2.93 GB to 872.79 MB (a 71% reduction), which accounts for nearly all of the total memory savings.
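Why Does This Happen?
Activation checkpointing is a memory optimization technique that trades extra computation for lower activation memory:
Without AC (baseline): all intermediate activations from the forward pass are kept in memory for the backward pass. GPT-2 has 12 Transformer layers, and each layer stores several activations (attention outputs, MLP outputs, and so on). With batch_size=4 and seq_length=512, these quickly add up to a large amount of memory.
With AC (optimized): only the activations at checkpoint boundaries are stored, and the intermediate activations are recomputed during the backward pass. This greatly reduces activation memory (by 71% in this example), while the other memory categories stay the same.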
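The "recompute instead of store" idea can be seen on a toy model using torch.utils.checkpoint.checkpoint. That is a general PyTorch utility; this sketch does not show how gradient_checkpointing_enable() applies it inside GPT-2. The block sizes and the residual MLP layout are arbitrary choices for a CPU-only run. The example checks that checkpointing changes what is stored, not the math: the losses and gradients match the non-checkpointed run.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Four small residual MLP blocks stand in for transformer layers (arbitrary sizes).
blocks = torch.nn.ModuleList(
    [
        torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))
        for _ in range(4)
    ]
)


def run(x, use_checkpointing):
    """Forward + backward; return the loss and a copy of every parameter gradient."""
    blocks.zero_grad(set_to_none=True)
    h = x
    for block in blocks:
        if use_checkpointing:
            # Activations inside the block are not kept; they are recomputed in backward.
            h = h + checkpoint(block, h, use_reentrant=False)
        else:
            h = h + block(h)
    loss = h.pow(2).mean()
    loss.backward()
    return loss.detach(), [p.grad.clone() for p in blocks.parameters()]


x = torch.randn(8, 16)
loss_plain, grads_plain = run(x, use_checkpointing=False)
loss_ckpt, grads_ckpt = run(x, use_checkpointing=True)
print(torch.allclose(loss_plain, loss_ckpt),
      all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt)))
```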
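How Mosaic Helps
Mosaic's categorical profiling immediately showed that:
Activation memory is the category with the largest difference (2.08 GB saved)
Backward/gradient memory stays nearly unchanged (793 MB → 785 MB)
Optimizer state stays unchanged (949 MB), which is expected because the model parameters did not change
Without Mosaic: you would need to instrument your code by hand, track allocations, and categorize them yourself.
With Mosaic: you get an immediate categorical breakdown with exact numbers, which makes it easy to identify and quantify memory optimizations.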
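Case 2: Debugging Unexpected Memory Usage
This section demonstrates how to use Mosaic to debug a model that uses more memory than expected when you are not sure why.
What we'll do:
Train GPT-2 and capture a memory snapshot.
Train GPT-2 with a bug that introduces extra memory usage, and capture a memory snapshot.
Use Mosaic to identify the likely sources of the extra memory usage.
The Problematic Model
This model contains leftover debug code that creates unnecessary GPU memory overhead. Someone added projection layers while debugging to "analyze hidden states" but forgot to remove them before training.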
class GPT2WithDebugOverhead(torch.nn.Module):
    """GPT2 wrapper with abandoned 'feature analysis' code that bloats peak memory.
    This wrapper adds extra projection layers that consume memory but serve no
    purpose - simulating abandoned debug code that was never cleaned up.
    """
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        config = base_model.config
        # BUG: Large projection layers from an abandoned experiment
        self.debug_projections = torch.nn.ModuleList(
            [
                torch.nn.Linear(config.n_embd, config.n_embd * 4)
                for _ in range(config.n_layer)
            ]
        )
        debug_params = sum(p.numel() for p in self.debug_projections.parameters())
        print(f" [DEBUG] Added {config.n_layer} debug projection layers")
        print(f" [DEBUG] Extra parameters: {debug_params:,}")
    def forward(self, input_ids=None, labels=None, **kwargs):
        # Run normal GPT-2 forward with hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )
        # BUG: Project all hidden states through debug layers
        projected = []
        for _layer_idx, (hidden, proj) in enumerate(
            zip(outputs.hidden_states[1:], self.debug_projections)
        ):
            proj_hidden = proj(hidden)
            projected.append(proj_hidden)
        # Tie to loss so gradients flow through
        debug_regularization = sum(p.mean() for p in projected) * 1e-10
        return CausalLMOutputWithCrossAttentions(
            loss=outputs.loss + debug_regularization,
            logits=outputs.logits,
        )
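Before reaching for snapshots, you can run a cheap complementary check: list the parameter count of each top-level submodule. For a wrapper like this one, that list would include the unexpected debug_projections entry. The sketch below is minimal. ToyWrapper is a hypothetical stand-in that only mimics the shape of GPT2WithDebugOverhead, so the example runs on CPU without downloading GPT-2.

```python
import torch


def params_by_child(model):
    """Map each top-level child module name to its parameter count."""
    return {
        name: sum(p.numel() for p in child.parameters())
        for name, child in model.named_children()
    }


class ToyWrapper(torch.nn.Module):
    """Hypothetical stand-in: a base model plus leftover per-layer projections."""

    def __init__(self, d=8, n_layer=3):
        super().__init__()
        self.base_model = torch.nn.Linear(d, d)
        self.debug_projections = torch.nn.ModuleList(
            [torch.nn.Linear(d, 4 * d) for _ in range(n_layer)]
        )


counts = params_by_child(ToyWrapper())
total = sum(counts.values())
for name, n in sorted(counts.items(), key=lambda kv: -kv[1]):
    print(f"{name:<18} {n:>8,} params ({100 * n / total:.0f}%)")
```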
Training Functions for the Debugging Comparison
def run_training_clean(snapshot_path, num_steps=3):
    """Training with the normal model."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    print("Loading clean model (no debug overhead)...")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Running training (should contain no debug overhead)...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}, Loss: {loss.item():.4f}")
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory
def run_training_with_bug(snapshot_path, num_steps=3):
    """Training with the buggy model."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    print("Loading buggy model with debug overhead...")
    # Load pretrained GPT-2 and wrap it with the debug overhead
    base_model = GPT2LMHeadModel.from_pretrained("gpt2")
    model = GPT2WithDebugOverhead(base_model).to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(
        vocab_size=tokenizer.vocab_size, seq_length=512, seed=42
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print("Running training (WITH debug overhead bug)...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f" Step {step + 1}, Loss: {loss.item():.4f}")
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak GPU memory: {peak_memory:.2f} GB")
    del model, optimizer
    torch.cuda.empty_cache()
    return peak_memory
Run Baseline (Clean Model) Training
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with baseline model")
    print("=" * 60)
    baseline_memory_debug = run_training_clean(
        "snapshot_debug_baseline.pickle", num_steps=3
    )
Run Training with the Bug
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("Training with debug projection overhead (BUG)")
    print("=" * 60)
    buggy_memory = run_training_with_bug("snapshot_with_bug.pickle", num_steps=3)
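Finding the Problem with Mosaic
Analyze both snapshots to find the source of the extra memory usage. We run Mosaic's peak memory analysis on each snapshot separately.
Analyze the Baseline (Clean) Snapshot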
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Baseline Snapshot")
    print("=" * 60)
    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_debug_baseline.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
Analyze the Buggy Snapshot
if HAS_CUDA and HAS_MOSAIC_CLI:
    print("=" * 60)
    print("MOSAIC: Analyzing the Buggy Snapshot")
    print("=" * 60)
    result = subprocess.run(
        ["mosaic_get_memory_usage_peak", "--snapshot", "snapshot_with_bug.pickle"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr)
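Analyzing the Mosaic Output
Mosaic's peak memory analysis shows the stack trace of each memory allocation. Here is how to use it to find the leftover or unnecessary code that causes memory bloat.
1. Optimizer State Allocation Delta
In the buggy snapshot's output, the first two stack traces are optimizer state allocations (for example, zeros_like for the Adam optimizer state); look for torch/optim/adam.py in the stack traces.
The buggy model's snapshot shows about 0.21 GB of extra memory in total: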
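| Version | Stack trace position | Number of calls | Memory (per trace) |
|---|---|---|---|
| Buggy model | 1st and 2nd | 172 calls | 0.569 GB + 0.569 GB |
| Baseline | 2nd and 3rd | 148 calls | 0.464 GB + 0.464 GB |
What this tells us: the optimizer is tracking more tensors! This is your first clue that the computation graph contains extra parameters or tensors.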
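You can reproduce the numbers in this table from first principles, which is a useful sanity check when reading peak-memory traces. The calculation rests on four assumptions:
- Parameters are fp32.
- Each of the two traces is one Adam moment buffer (exp_avg or exp_avg_sq) of the same size as the parameters.
- Each "call" is one zeros_like per parameter tensor.
- GPT-2 small has 124,439,808 parameters in 148 tensors, to which GPT2WithDebugOverhead adds 12 Linear(768, 3072) layers.

```python
GIB = 1024**3
BYTES_FP32 = 4  # assumption: fp32 parameters and optimizer state

# GPT-2 small: wte + wpe, 12 tensors per transformer block, ln_f weight + bias.
base_tensors = 2 + 12 * 12 + 2
base_params = 124_439_808  # LM head tied to the token embedding

# GPT2WithDebugOverhead adds n_layer Linear(n_embd, 4 * n_embd) layers.
n_layer, n_embd = 12, 768
debug_tensors = n_layer * 2                                   # weight + bias each
debug_params = n_layer * (n_embd * 4 * n_embd + 4 * n_embd)

buggy_tensors = base_tensors + debug_tensors
buggy_params = base_params + debug_params

# Assumption: each trace is one Adam moment buffer, the same size as the parameters.
base_trace_gib = base_params * BYTES_FP32 / GIB
buggy_trace_gib = buggy_params * BYTES_FP32 / GIB
extra_gib = 2 * (buggy_trace_gib - base_trace_gib)            # two moment buffers

print(f"zeros_like calls per trace: {base_tensors} -> {buggy_tensors}")
print(f"memory per trace: {base_trace_gib:.3f} -> {buggy_trace_gib:.3f} GiB")
print(f"extra optimizer state: {extra_gib:.2f} GiB from {debug_params:,} extra params")
```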
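2. Extra Activation Allocations
The buggy version shows additional allocations that do not appear in the baseline model. Scrolling further down the buggy model's Mosaic output, we find extra stack traces containing:
torch::autograd::Engine::evaluate_function: we are in the backward pass
AddmmBackward0::apply: computing the gradient of an addmm operation
empty_cuda at the bottom of the trace: allocating a new CUDA tensor to store the gradient
0.176 GB from matrix-multiplication gradients (AddmmBackward0, mm_mat1_backward)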
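Explanation of Memory Totals
Total peak dynamic memory usage: the peak of the memory that changes during execution, measured relative to the start of the snapshot. It tracks the allocations that happen during the traced execution timeline.
Total static memory usage: the "starting" or baseline memory that already existed before tracing began. The PyTorch visualizer estimates it, and it stays constant throughout the snapshot (it has no accompanying stack traces).
Note
The two snapshots may also differ in total static memory usage, which accounts for the remaining difference between the runs.
Overall peak memory usage: dynamic + static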
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f"Baseline (clean model): {baseline_memory_debug:.2f} GB")
    print(f"With bug (debug projections): {buggy_memory:.2f} GB")
    print(
        f"Extra memory from bug: {buggy_memory - baseline_memory_debug:.2f} GB"
    )
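Because overall peak = dynamic + static, the gap between two runs splits exactly into a dynamic part and a static part. The sketch below shows that bookkeeping. It assumes dictionaries with the byte-valued keys that the analyze_training_memory helper in Case 3 returns (dynamic_peak_memory_bytes and static_memory_bytes). The attribute_peak_difference name is hypothetical.

```python
GIB = 1024**3


def attribute_peak_difference(baseline, candidate):
    """Split the overall-peak gap between two analyses into dynamic and static parts (GiB).

    Assumes dicts with the keys returned by analyze_training_memory (Case 3).
    """
    d_dynamic = candidate["dynamic_peak_memory_bytes"] - baseline["dynamic_peak_memory_bytes"]
    d_static = candidate["static_memory_bytes"] - baseline["static_memory_bytes"]
    return {
        "dynamic_delta_gib": d_dynamic / GIB,
        "static_delta_gib": d_static / GIB,
        "overall_delta_gib": (d_dynamic + d_static) / GIB,
    }
```

Once analyze_training_memory is defined, you could call `attribute_peak_difference(analyze_training_memory("snapshot_debug_baseline.pickle"), analyze_training_memory("snapshot_with_bug.pickle"))` to see how much of the gap is static.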
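Case 3: Integrating Memory Analysis into Your Training Pipeline
This section demonstrates how to use Mosaic to capture memory snapshots automatically during training, obtain structured memory breakdowns for monitoring and dashboards, and use Mosaic programmatically (as a Python dependency) to build automated memory monitoring for large-scale training.
Mosaic lets you integrate memory analysis directly into your training pipeline.
Training with Automatic Memory Capture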
def run_training_with_memory_capture(
    batch_size=4,
    seq_length=512,
    num_steps=5,
    snapshot_path="training_snapshot.pickle",
):
    """Run training and automatically capture memory snapshot."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    device = torch.device("cuda")
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    model.train()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    dataset = RandomTokenDataset(tokenizer.vocab_size, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    print(f"Running {num_steps} training steps with memory capture...")
    with capture_memory_snapshot(snapshot_path):
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], labels=batch["labels"])
            outputs.loss.backward()
            optimizer.step()
            print(f" Step {step + 1}/{num_steps}, Loss: {outputs.loss.item():.4f}")
    peak_memory_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"✓ PyTorch reported peak memory: {peak_memory_gb:.3f} GB")
    del model, optimizer
    torch.cuda.empty_cache()
    return snapshot_path
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("CASE 3: Pipeline Integration")
    print("=" * 60)
    pipeline_snapshot_path = run_training_with_memory_capture(batch_size=4, seq_length=512)
Mosaic Memory Analysis via the Python API
Instead of calling the CLI commands, we can use Mosaic's Python API directly for programmatic integration.
if HAS_CUDA:
    print("\n" + "=" * 60)
    print("MOSAIC MEMORY ANALYSIS (via Python API)")
    print("=" * 60)
    # Load and analyze the memory snapshot
    memory_abstract = MemoryAbstract(memory_snapshot_file=pipeline_snapshot_path)
    memory_abstract.load_memory_snapshot()
    # Analyze peak memory usage
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")
    # Get results
    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory
    print(f"Peak dynamic memory: {dynamic_peak / 1024**3:.3f} GiB")
    print(f"Static memory: {static_memory / 1024**3:.3f} GiB")
    print(f"Overall peak memory: {overall_peak / 1024**3:.3f} GiB")
    print("✓ Analysis complete using Mosaic Python API")
Reusable Memory Analysis Function
Create a reusable function for analyzing training memory snapshots.
def analyze_training_memory(snapshot_path):
    """Analyze a memory snapshot using Mosaic's Python API.
    Returns a structured dictionary with memory breakdown.
    Args:
        snapshot_path: Path to the memory snapshot pickle file.
    Returns:
        Dictionary containing memory analysis results.
    """
    # Load snapshot
    memory_abstract = MemoryAbstract(memory_snapshot_file=snapshot_path)
    memory_abstract.load_memory_snapshot()
    # Analyze peak memory
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")
    # Extract results
    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory
    return {
        "snapshot_path": snapshot_path,
        "dynamic_peak_memory_bytes": dynamic_peak,
        "static_memory_bytes": static_memory,
        "overall_peak_memory_bytes": overall_peak,
        "dynamic_peak_memory_gib": dynamic_peak / 1024**3,
        "static_memory_gib": static_memory / 1024**3,
        "overall_peak_memory_gib": overall_peak / 1024**3,
    }
if HAS_CUDA:
    analysis = analyze_training_memory(pipeline_snapshot_path)
    print("\nMemory Analysis Result:")
    for key, value in analysis.items():
        print(f" {key}: {value}")
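Complete Training Pipeline with Memory Monitoring
This example shows a production-style training pipeline with integrated Mosaic memory monitoring, suitable for CI/CD, monitoring dashboards, or capacity planning.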
def training_pipeline_with_memory_monitoring(
    model_name: str,
    batch_size: int,
    seq_length: int,
    num_steps: int = 5,
    snapshot_path: str = "pipeline_snapshot.pickle",
) -> dict:
    """Complete training pipeline with integrated Mosaic memory monitoring.
    Can be integrated into CI/CD, monitoring dashboards, or capacity planning.
    Args:
        model_name: HuggingFace model name to use.
        batch_size: Training batch size.
        seq_length: Sequence length for input tokens.
        num_steps: Number of training steps.
        snapshot_path: Path to save the memory snapshot.
    Returns:
        Dictionary containing training and memory analysis report.
    """
    # The CUDA memory statistics used below require a CUDA device
    device = torch.device("cuda")
    # Setup
    print(f"Loading model: {model_name}")
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Training with memory capture
    print(f"Running {num_steps} training steps...")
    with capture_memory_snapshot(snapshot_path):
        for step in range(num_steps):
            input_ids = torch.randint(
                0, tokenizer.vocab_size, (batch_size, seq_length)
            ).to(device)
            outputs = model(input_ids=input_ids, labels=input_ids)
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print(f" Step {step + 1}/{num_steps}, Loss: {outputs.loss.item():.4f}")
    pytorch_peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    # Mosaic analysis using Python API
    print("Analyzing memory with Mosaic...")
    memory_abstract = MemoryAbstract(memory_snapshot_file=snapshot_path)
    memory_abstract.load_memory_snapshot()
    memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")
    dynamic_peak = memory_abstract.memory_snapshot.dynamic_memory_peak
    static_memory = memory_abstract.memory_snapshot.static_memory
    overall_peak = dynamic_peak + static_memory
    report = {
        "model": model_name,
        "config": {
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
        },
        "pytorch_peak_memory_gb": pytorch_peak_gb,
        "mosaic_analysis": {
            "dynamic_peak_gib": dynamic_peak / 1024**3,
            "static_memory_gib": static_memory / 1024**3,
            "overall_peak_gib": overall_peak / 1024**3,
        },
        "snapshot_path": snapshot_path,
    }
    del model, optimizer
    torch.cuda.empty_cache()
    return report
# Run the pipeline
if HAS_CUDA:
    report = training_pipeline_with_memory_monitoring(
        "gpt2", batch_size=4, seq_length=512, num_steps=5
    )
    print("\n" + "=" * 60)
    print("PIPELINE REPORT")
    print("=" * 60)
    print(f"Model: {report['model']}")
    print(f"Config: {report['config']}")
    print(f"PyTorch Peak Memory: {report['pytorch_peak_memory_gb']:.3f} GB")
    print(f"Mosaic Dynamic Peak: {report['mosaic_analysis']['dynamic_peak_gib']:.3f} GiB")
    print(f"Mosaic Overall Peak: {report['mosaic_analysis']['overall_peak_gib']:.3f} GiB")
CI/CD and Dashboard Integration Patterns
These patterns show how to integrate Mosaic analysis into automated workflows.
import json
Pattern 1: CI/CD Memory Regression Test
def check_memory_regression(report, threshold_gib=5.0):
    """Check if memory usage exceeds threshold for CI/CD pipelines.
    Args:
        report: Memory analysis report from training_pipeline_with_memory_monitoring.
        threshold_gib: Maximum allowed memory in GiB.
    Raises:
        AssertionError: If memory exceeds threshold.
    """
    peak = report["mosaic_analysis"]["overall_peak_gib"]
    assert peak < threshold_gib, (
        f"Memory regression! {peak:.2f} GiB > {threshold_gib} GiB"
    )
    print(f"Memory check passed: {peak:.2f} GiB < {threshold_gib} GiB threshold")
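A fixed threshold has to be tuned for each model and batch size. An alternative is to compare against a stored baseline report and fail on relative growth. This is a minimal sketch. It assumes the report layout produced by training_pipeline_with_memory_monitoring and a baseline report saved earlier as JSON. The function names and the 5% tolerance are illustrative choices. Unlike a bare assert, the explicit raise still fires when Python runs with -O.

```python
import json


def check_relative_regression(report, baseline_report, max_increase=0.05):
    """Raise if overall peak memory grew by more than max_increase vs. a baseline report.

    Assumes the report layout of training_pipeline_with_memory_monitoring.
    Returns the relative growth (e.g. 0.025 for +2.5%).
    """
    current = report["mosaic_analysis"]["overall_peak_gib"]
    reference = baseline_report["mosaic_analysis"]["overall_peak_gib"]
    growth = (current - reference) / reference
    if growth > max_increase:
        raise AssertionError(
            f"Memory regression: {current:.2f} GiB vs baseline {reference:.2f} GiB "
            f"(+{growth:.1%} > {max_increase:.0%} allowed)"
        )
    return growth


def load_report(path):
    """Load a report previously written with json.dump."""
    with open(path) as f:
        return json.load(f)
```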
Pattern 2: Export to JSON for Dashboards
if HAS_CUDA:
    check_memory_regression(report, threshold_gib=8.0)
    with open("memory_report.json", "w") as f:
        json.dump(report, f, indent=2, default=str)
    print("Memory report exported to memory_report.json")
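Dashboards that plot memory over time may be simpler to feed if each run appends one JSON line instead of overwriting a single file. This is a minimal sketch. The memory_history.jsonl file name, the added timestamp field, and the function names are assumptions, not part of the tutorial's report format.

```python
import json
import time


def append_report(report, path="memory_history.jsonl"):
    """Append one report per line (JSON Lines), stamped with the current UTC time."""
    record = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), **report}
    with open(path, "a") as f:
        f.write(json.dumps(record, default=str) + "\n")


def read_history(path="memory_history.jsonl"):
    """Return all recorded reports, oldest first."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
```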
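Conclusion
This tutorial demonstrated three key use cases for Mosaic memory profiling:
Case 1: Activation checkpointing analysis
Used Mosaic to compare memory usage between the baseline and optimized models
Found that activation checkpointing reduced activation memory by 71%
Showed how Mosaic's categorical profiling makes it easy to pinpoint where the memory savings come from
Case 2: Debugging unexpected memory usage
Created a "buggy" model containing leftover debug code
Used mosaic_get_memory_usage_peak to identify the extra allocations
Found, from the stack traces, that the optimizer state was tracking extra parameters
Case 3: Pipeline integration
Demonstrated programmatic usage through Mosaic's Python API
Showed CI/CD and dashboard integration patterns built on structured reports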