 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 c10d_ops = torch.ops.c10d
+funcol = torch.ops.c10d_functional
 
 
 class TestFullyShardForwardInputs(FSDPTestMultiThread):
@@ -927,7 +928,10 @@ def _test_train_parity_2d_mlp(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
 
@@ -943,6 +947,61 @@ def _test_train_parity_2d_mlp(
             _optim.step()
         self.assertEqual(losses[0], losses[1])
 
+    @skip_if_lt_x_gpu(2)
+    def test_tp_with_fsdp_offloading(self):
+        global_mesh = init_device_mesh(
+            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        torch.manual_seed(42)
+        mlp_dim = 16
+        model = MLPStack(mlp_dim)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
+        # Parallelize with N-way TP and 1-way FSDP
+        model.parallelize(
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing=False,
+            reshard_after_forward=True,
+            offload_policy=CPUOffloadPolicy(),
+        )
+        for param in model.parameters():
+            self.assertEqual(param.device.type, "cpu")
+        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
+
+        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
+        # called, but they will just be no-ops without issuing any kernels.
+        # We prefer to keep the no-op check at the c10d level, not in FSDP.
+        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        for iter_idx in range(10):
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+            with CommDebugMode() as fwd_comm_mode:
+                loss = model(inp).sum()
+
+            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(fwd_comm_counts), 2)
+            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
+            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            ref_loss = ref_model(inp).sum()
+            self.assertEqual(loss, ref_loss)
+
+            with CommDebugMode() as bwd_comm_mode:
+                loss.backward()
+            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(bwd_comm_counts), 3)
+            # First MLP's input gradient does not need to be all-reduced
+            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
+            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
+            ref_loss.backward()
+
+            optim.step()
+            ref_optim.step()
+
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
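
For readers unfamiliar with the composition the new test exercises, here is a minimal, self-contained sketch of wiring up 2D tensor parallelism plus FSDP with CPU offloading through the public APIs. It assumes the FSDP2 import paths in use around the time of this change (torch.distributed._composable.fsdp); the toy nn.Sequential model, dimensions, and names are illustrative and are not the test's MLPStack. As the NOTE in the test says, with a size-1 "dp" mesh the FSDP all-gather/reduce-scatter c10d ops are still invoked but launch no kernels.

# Minimal sketch, not the test itself: 1-way FSDP (with CPU offload)
# composed with 2-way TP on a toy model. Assumes 2 GPUs and a torchrun
# launch (e.g., torchrun --nproc_per_node=2 sketch.py).
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 1-way data parallel x 2-way tensor parallel, mirroring the test's mesh
global_mesh = init_device_mesh("cuda", (1, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
# Colwise-in/rowwise-out TP: each such pair issues one all-reduce in the
# forward pass, which is what the per-MLP all-reduce assertions count.
parallelize_module(model, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})
# FSDP with CPU offloading: sharded parameters are kept on CPU between
# uses, hence the param.device.type == "cpu" check in the test.
fully_shard(model, mesh=dp_mesh, offload_policy=CPUOffloadPolicy())
assert all(p.device.type == "cpu" for p in model.parameters())

inp = torch.randn(4, 16, device="cuda")
with CommDebugMode() as comm_mode:
    model(inp).sum().backward()
print(comm_mode.get_comm_counts())  # maps collective ops to call counts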