Conversation
| num_tokens_per_expert.view(ep_size, -1) | ||
| .sum(dim=1) | ||
| .to(torch.device("cpu"), non_blocking=True) | ||
| .to(torch.device("cpu"), non_blocking=False) |
There was a problem hiding this comment.
oh could you remind me of the reason we have to use non_blocking=False?
I think it may not matter too much as this two d2h syncs are adjacent to each other.
If we have to do this, we can remove the non_blocking arg as False is the default.
There was a problem hiding this comment.
that is a good point, we can avoid blocking for the first .to(), although yeah I don't think it changed tps very much
| torch.ops.aten._scaled_dot_product_efficient_attention.default, | ||
| torch.ops.aten._scaled_dot_product_flash_attention.default, | ||
| torch.ops._c10d_functional.reduce_scatter_tensor.default, | ||
| torch.ops._c10d_functional.all_to_all_single.default, |
There was a problem hiding this comment.
Could you confirm that this saves both:
- the
dist.all_to_all_singleto obtain routing info - the actual
all_to_all_single_autogradto route tokens
I think ideally we'd like both to be saved.
There was a problem hiding this comment.
Good catch, dist.all_to_all_single is actually a different op.
Do you know why we use two different all-to-alls here? I don't think dist.all_to_all_single works with SAC, it
mutates a "output" tensor that the user provides and returns a work object.
There was a problem hiding this comment.
oh, there may not be strong reason to.
Could you try the fun col version? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L445
If it works we can switch to this one, and hopefully the AC policy would capture both, because underlying the same torch.ops._c10d_functional.all_to_all_single.default gets called.
There was a problem hiding this comment.
Updated to use the fun col version
|
Just a FYI, @soulitzer , #1675 conflicts with this PR. |
| ) | ||
| num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( | ||
| num_tokens_per_expert_group | ||
| ) |
There was a problem hiding this comment.
We need an explicit wait because num_tokens_per_expert_group gets used by a triton kernel, which doesn't realize that AsyncCollectiveTensor needs to be unwrapped.
There was a problem hiding this comment.
could you make this a comment in the code? I think it's very helpful.
There was a problem hiding this comment.
Hi @soulitzer, I have some questions here:
(1) If we need to issue this wait explicitly because of a triton kernel, does this mean in eager/aot_eager, it's okay to not explicit wait on num_tokens_per_expert_group?
(2) Would be super curious which triton kernel you are referring to : )
There was a problem hiding this comment.
Hmm not sure because this triton kernel is a kernel the user issues explicitly (not one that inductor generates).
Would be super curious which triton kernel you are referring to
I don't remember exactly, but it should be possible to check by removing the wait. The triton kernel will error because it tries to get the null storage ptr from the wrapper tensor subclass.
There was a problem hiding this comment.
i see, thank you for the clarification. I found this when checking tlparse where a wait_tensor is waiting for another wait_tensor. I'm just curious about what is happening here loll.
FYI, I didn't hit errors after removing the wait_tensor.
tianyu-l
left a comment
There was a problem hiding this comment.
Had one more question regarding saving the results of wait_tensor.
Also it would be great if you could share some profiler traces in PR summary.
| torch.ops.aten._scaled_dot_product_flash_attention.default, | ||
| torch.ops._c10d_functional.reduce_scatter_tensor.default, | ||
| torch.ops._c10d_functional.all_to_all_single.default, | ||
| torch.ops._c10d_functional.wait_tensor.default, |
There was a problem hiding this comment.
I wonder if this has any side effect, as in the mapping from all collectives and wait is many-to-one. In particular,
- Would this line save all the communication results, not only from a2a but also e.g. TP all-gather?
- Would not having this line save none of the communication results? I.e. did the
torch.ops._c10d_functional.reduce_scatter_tensor.default,line take effect?
There was a problem hiding this comment.
I don't think there should be side effect, unless there are other wait_tensors being explicitly called.
-
Ordinarily, the AsyncCollectiveTensor triggers the wait before executing the op in its torch dispatch, so it would actually be hidden from SAC (user modes execute before user subclasses unwrap). SAC should only be able to see / save the wait if we're calling it explicitly here.
-
The lines for reduce scatter, etc will save AsyncCollectiveTensor, and in the original forward, when wait happens via the subclasses's torch dispatch, the wait result should be cached onto the AsyncCollectiveTensor, so that a second wait should not be triggered during recompute.
That being said. I'm not actually entirely sure what happens when you executing wait explicitly on an AsyncCollectiveTensor again even though the collective has already been waited on. Checking again, removing it doesn't seem to affect tps, so I think I will remove it.
|
Added links to some profiler traces in the summary. From staring at the traces, saving the wait_tensor reduces the cpu overhead from 200us to 40us, but doesn't really seem to affect tps, so removing it to minimize risk of side effects. |
| ) | ||
| num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor( | ||
| num_tokens_per_expert_group | ||
| ) |
There was a problem hiding this comment.
could you make this a comment in the code? I think it's very helpful.
| .to(torch.device("cpu"), non_blocking=True) | ||
| .to(torch.device("cpu"), non_blocking=False) | ||
| ) | ||
| # NOTE: this would incur a device-to-host sync |
There was a problem hiding this comment.
please move this note to the actual blocking call above
| and "device" in kwargs | ||
| and str(kwargs["device"]) == "cpu" | ||
| ): | ||
| return CheckpointPolicy.MUST_SAVE |
There was a problem hiding this comment.
hmm did this take effect? I would guess we don't need to do any d2h sync in backward anymore, but in the traces I'm still seeing them in backward.
(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch
There was a problem hiding this comment.
Hmm there is still cudaStreamSync in FlexAttentionBackward but it is expected since SAC only takes effect for the replay of the forward. Is there another place where you see it?
``` CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh --parallelism.expert_parallel_degree 4 --model.hf_assets_path "./assets/hf/deepseek-moe-16b-base" ``` Before (not saving a2a and to_copy) ``` [rank0]:[titan] 2025-09-01 16:03:30,779 - root - INFO - step: 1 loss: 12.0469 grad_norm: 1.8420 memory: 61.72GiB(77.99%) tps: 297 tflops: 4.48 mfu: 1.43% [rank0]:[titan] 2025-09-01 16:03:30,780 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-09-01 16:04:22,882 - root - INFO - step: 10 loss: 11.2800 grad_norm: 2.4321 memory: 70.02GiB(88.48%) tps: 708 tflops: 10.65 mfu: 3.41% ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_recompute.json.gz&bucket=pytorch After (saving a2a and to_copy) ``` [rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - step: 1 loss: 12.0470 grad_norm: 1.8420 memory: 64.49GiB(81.50%) tps: 321 tflops: 4.82 mfu: 1.55% [rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-09-01 16:02:25,603 - root - INFO - step: 10 loss: 11.2801 grad_norm: 2.4322 memory: 74.53GiB(94.17%) tps: 803 tflops: 12.08 mfu: 3.87% ``` (save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch (don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch
``` CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh --parallelism.expert_parallel_degree 4 --model.hf_assets_path "./assets/hf/deepseek-moe-16b-base" ``` Before (not saving a2a and to_copy) ``` [rank0]:[titan] 2025-09-01 16:03:30,779 - root - INFO - step: 1 loss: 12.0469 grad_norm: 1.8420 memory: 61.72GiB(77.99%) tps: 297 tflops: 4.48 mfu: 1.43% [rank0]:[titan] 2025-09-01 16:03:30,780 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-09-01 16:04:22,882 - root - INFO - step: 10 loss: 11.2800 grad_norm: 2.4321 memory: 70.02GiB(88.48%) tps: 708 tflops: 10.65 mfu: 3.41% ``` https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_recompute.json.gz&bucket=pytorch After (saving a2a and to_copy) ``` [rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - step: 1 loss: 12.0470 grad_norm: 1.8420 memory: 64.49GiB(81.50%) tps: 321 tflops: 4.82 mfu: 1.55% [rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-09-01 16:02:25,603 - root - INFO - step: 10 loss: 11.2801 grad_norm: 2.4322 memory: 74.53GiB(94.17%) tps: 803 tflops: 12.08 mfu: 3.87% ``` (save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch (don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch
Before (not saving a2a and to_copy)
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_recompute.json.gz&bucket=pytorch
After (saving a2a and to_copy)
(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch