Commit 9117779

Andrew Gu authored and pytorchmergebot committed
[FSDP2] Added test for N-way TP and 1-way FSDP with CPU offloading (#127024)
This PR shows that we can use FSDP solely for CPU offloading when composing with N-way TP: each FSDP mesh is just 1 rank. This was motivated by an ask on Slack :)

Pull Request resolved: #127024
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #127004
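For context, a minimal sketch of the pattern the new test exercises: N-way TP across the "tp" mesh dimension plus a 1-rank FSDP mesh used purely for CPU offloading. This is not the PR's test code — the toy `nn.Sequential` model and its Colwise/Rowwise plan are placeholders for the test's `MLPStack` — and it assumes a CUDA machine launched via `torchrun --nproc-per-node=N`:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def main() -> None:
    # The "dp" dim has size 1, so FSDP never really shards parameters;
    # it only manages CPU <-> GPU movement via the offload policy.
    mesh = init_device_mesh(
        "cuda", (1, dist.get_world_size()), mesh_dim_names=("dp", "tp")
    )
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
    # N-way TP across the "tp" mesh dimension (illustrative plan).
    parallelize_module(
        model, mesh["tp"], {"0": ColwiseParallel(), "1": RowwiseParallel()}
    )
    # 1-way FSDP used solely for CPU offloading.
    fully_shard(model, mesh=mesh["dp"], offload_policy=CPUOffloadPolicy())
    # As in the test: parameters live on CPU between uses.
    assert all(p.device.type == "cpu" for p in model.parameters())


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    main()
    dist.destroy_process_group()
```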
1 parent 87f79af commit 9117779

2 files changed: 63 additions & 6 deletions

File tree

test/distributed/_composable/fsdp/test_fully_shard_training.py
torch/testing/_internal/common_fsdp.py

test/distributed/_composable/fsdp/test_fully_shard_training.py

Lines changed: 60 additions & 1 deletion
```diff
@@ -56,6 +56,7 @@
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 c10d_ops = torch.ops.c10d
+funcol = torch.ops.c10d_functional
 
 
 class TestFullyShardForwardInputs(FSDPTestMultiThread):
@@ -927,7 +928,10 @@ def _test_train_parity_2d_mlp(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
@@ -943,6 +947,61 @@ def _test_train_parity_2d_mlp(
                 _optim.step()
             self.assertEqual(losses[0], losses[1])
 
+    @skip_if_lt_x_gpu(2)
+    def test_tp_with_fsdp_offloading(self):
+        global_mesh = init_device_mesh(
+            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        torch.manual_seed(42)
+        mlp_dim = 16
+        model = MLPStack(mlp_dim)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
+        # Parallelize with N-way TP and 1-way FSDP
+        model.parallelize(
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing=False,
+            reshard_after_forward=True,
+            offload_policy=CPUOffloadPolicy(),
+        )
+        for param in model.parameters():
+            self.assertEqual(param.device.type, "cpu")
+        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
+
+        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
+        # called, but they will just be no-ops without issuing any kernels.
+        # We prefer to keep the no-op check at the c10d level, not in FSDP.
+        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        for iter_idx in range(10):
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+            with CommDebugMode() as fwd_comm_mode:
+                loss = model(inp).sum()
+
+            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(fwd_comm_counts), 2)
+            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
+            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            ref_loss = ref_model(inp).sum()
+            self.assertEqual(loss, ref_loss)
+
+            with CommDebugMode() as bwd_comm_mode:
+                loss.backward()
+            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(bwd_comm_counts), 3)
+            # First MLP's input gradient does not need to be all-reduced
+            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
+            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
+            ref_loss.backward()
+
+            optim.step()
+            ref_optim.step()
+
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
```
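The NOTE in the new test leans on the fact that collectives over a single-rank group complete without real communication, which is why the c10d ops still show up in `CommDebugMode` counts but launch no kernels. A standalone sketch of that behavior — not part of the PR, using a 1-rank gloo group so it runs without GPUs:

```python
import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    t = torch.arange(4, dtype=torch.float32)
    dist.all_reduce(t)  # sum over a single rank: t is unchanged
    assert torch.equal(t, torch.arange(4, dtype=torch.float32))

    out = torch.empty_like(t)
    # Degenerates to a local copy on a 1-rank group (gloo support for this
    # op can vary by torch version; the all_reduce above makes the same point).
    dist.all_gather_into_tensor(out, t)
    assert torch.equal(out, t)

    dist.destroy_process_group()
```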

torch/testing/_internal/common_fsdp.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -893,7 +893,7 @@ def parallelize(
         tp_mesh: DeviceMesh,
         dp_mesh: DeviceMesh,
         use_activation_checkpointing: bool,
-        reshard_after_forward: bool,
+        **fsdp_kwargs,
     ) -> "MLPStack":
         parallelize_plan = {
             # Pass `use_local_output=False` to keep as DTensor to preserve
@@ -915,10 +915,8 @@ def parallelize(
                 continue
             if use_activation_checkpointing:
                 checkpoint(module)
-            fully_shard(
-                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
-            )
-        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
         return self
 
```
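With the signature now taking `**fsdp_kwargs`, call sites can forward any `fully_shard` option through `MLPStack.parallelize`. The new test does exactly this to pass the offload policy — a fragment from the diff above, where `model`, `tp_mesh`, and `dp_mesh` are defined:

```python
model.parallelize(
    tp_mesh,
    dp_mesh,
    use_activation_checkpointing=False,
    reshard_after_forward=True,  # forwarded to each fully_shard call
    offload_policy=CPUOffloadPolicy(),  # forwarded to each fully_shard call
)
```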