[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result#130760
[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result#130760XilunWu wants to merge 17 commits intogh/XilunWu/90/basefrom
Conversation
…rrect full_tensor() result [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130760
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 3 Unrelated FailuresAs of commit 2d4314e with merge base da32021 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…ding for correct full_tensor() result" cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
test/distributed/_composable/fsdp/test_fully_shard_state_dict.py
Outdated
Show resolved
Hide resolved
…ding for correct full_tensor() result" Fixes issue #129229 #129206 **Summary** 1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding 2. Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])` and the `split_factor` argument will just be the number of shards on that sharding tensor dim. 3. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and simply TP sharding (i.e. non-strided) has the same `full_tensor()` result 4. Re-enabled the tests that were disabled in #129519 and removed relevant code **test** `pytest test/distributed/_composable/fsdp/test_fully_shard_training.py` `pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py` cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang [ghstack-poisoned]
…rrect full_tensor() result ghstack-source-id: d996bc8 Pull Request resolved: pytorch/pytorch#130760
…ding for correct full_tensor() result" Fixes issue #129229 #129206 **Summary** 1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding 2. Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])` and the `split_factor` argument will just be the number of shards on that sharding tensor dim. 3. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and simply TP sharding (i.e. non-strided) has the same `full_tensor()` result 4. Re-enabled the tests that were disabled in #129519 and removed relevant code **test** `pytest test/distributed/_composable/fsdp/test_fully_shard_training.py` `pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py` cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang [ghstack-poisoned]
| self._spmd_placements, | ||
| tensor_meta=self._tp_spec.tensor_meta, | ||
| ) | ||
| # NOTE: FSDP+TP does not support uneven sharding for now |
There was a problem hiding this comment.
hmmm I think this check is too strict, we should ONLY check the tensor dimension 0 for uneven sharding when the TP sharding spec sharded on tensor dim 0, as:
- DTensor supports uneven sharding computation for its normal case.
- Only when strided sharding enabled we don't yet support uneven sharding, and the strided sharding only enabled when the TP's DTensor is sharded on tensor dim 0 (and FSDP2 shards on tensor dim 0 too, which is the case where StridedShard comes in)
So please fix this check to be more accurate
There was a problem hiding this comment.
thanks for catching the mistake!! I added a check on split_factor and only perform the sharding evenness check when it's larger than 1 (i.e. tensor dim 0 has been sharded on TP mesh dim).
…ding for correct full_tensor() result" Fixes issue #129229 #129206 **Summary** 1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding 2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and simply TP sharding (i.e. non-strided) has the same `full_tensor()` result 3. Re-enabled the tests that were disabled in #129519 **test** `pytest test/distributed/_composable/fsdp/` `pytest test/distributed/_composable/test_composability/test_2d_composability.py` `pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py` cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang LucasLLC MeetVadakkanchery mhorowitz Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114) [ghstack-poisoned]
…ding for correct full_tensor() result" Fixes issue #129229 #129206 **Summary** 1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding 2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and simply TP sharding (i.e. non-strided) has the same `full_tensor()` result 3. Re-enabled the tests that were disabled in #129519 **test** `pytest test/distributed/_composable/fsdp/` `pytest test/distributed/_composable/test_composability/test_2d_composability.py` `pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py` cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang LucasLLC MeetVadakkanchery mhorowitz Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114) [ghstack-poisoned]
|
@pytorchbot merge -f "unrelated CI failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
**Summary** 1. re-enable FSDP+TP 2D in torchtitan. 2. remove temporary re-enablement added in #460 **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…ch version that not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…ch version that not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…ng DTensor strided sharding (#507) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #507 **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ```
**Summary** 1. re-enable FSDP+TP 2D in torchtitan. 2. remove temporary re-enablement added in #460 **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
**Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…not including DTensor strided sharding" **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ``` When to merge: when pytorch/pytorch#130760 is in nightly build. [ghstack-poisoned]
…ng DTensor strided sharding (#507) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #507 **Summary** 1. check if users are using new nightly-build pytorch that includes DTensor strided sharding (pytorch/pytorch#130760) when 2D/3D is used. Print warning if not. 2. remove temporary re-enablement added in #460 . **Test** Command: `python test_runner.py outputs --test pp_dp_tp --ngpu 8` GPUs: A100 Output: - without strided sharding: ``` [rank7]:2024-08-06 03:21:26,706 - root - INFO - step: 2 loss: 8.1652 memory: 0.51GiB(0.64%) wps: 8,250 mfu: 0.25% [rank7]:2024-08-06 03:21:27,013 - root - INFO - step: 3 loss: 8.0951 memory: 0.51GiB(0.64%) wps: 13,358 mfu: 0.41% [rank7]:2024-08-06 03:21:27,309 - root - INFO - step: 4 loss: 7.9748 memory: 0.51GiB(0.64%) wps: 13,865 mfu: 0.42% [rank7]:2024-08-06 03:21:27,582 - root - INFO - step: 5 loss: 7.8025 memory: 0.51GiB(0.64%) wps: 15,057 mfu: 0.46% [rank7]:2024-08-06 03:21:28,076 - root - INFO - step: 6 loss: 7.5612 memory: 0.51GiB(0.64%) wps: 8,300 mfu: 0.25% [rank7]:2024-08-06 03:21:28,608 - root - INFO - step: 7 loss: 7.3649 memory: 0.51GiB(0.64%) wps: 7,705 mfu: 0.23% [rank7]:2024-08-06 03:21:28,927 - root - INFO - step: 8 loss: 7.2946 memory: 0.51GiB(0.64%) wps: 12,832 mfu: 0.39% [rank7]:2024-08-06 03:21:29,251 - root - INFO - step: 9 loss: 7.1311 memory: 0.51GiB(0.64%) wps: 12,669 mfu: 0.38% [rank7]:2024-08-06 03:21:29,627 - root - INFO - step: 10 loss: 7.0540 memory: 0.51GiB(0.64%) wps: 10,918 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:21:59,723 - root - INFO - step: 11 loss: 7.0822 memory: 0.51GiB(0.64%) wps: 1,139 mfu: 0.03% [rank7]:2024-08-06 03:22:00,054 - root - INFO - step: 12 loss: 7.0508 memory: 0.51GiB(0.64%) wps: 12,366 mfu: 0.38% [rank7]:2024-08-06 03:22:00,340 - root - INFO - step: 13 loss: 6.9182 memory: 0.51GiB(0.64%) wps: 14,370 mfu: 0.44% [rank7]:2024-08-06 03:22:00,624 - root - INFO - step: 14 loss: 6.8948 memory: 0.51GiB(0.64%) wps: 14,442 mfu: 0.44% [rank7]:2024-08-06 03:22:00,907 - root - INFO - step: 15 loss: 6.8358 memory: 0.51GiB(0.64%) wps: 14,514 mfu: 0.44% [rank7]:2024-08-06 03:22:01,574 - root - INFO - step: 16 loss: 6.7653 memory: 0.51GiB(0.64%) wps: 6,144 mfu: 0.19% [rank7]:2024-08-06 03:22:02,209 - root - INFO - step: 17 loss: 6.7340 memory: 0.51GiB(0.64%) wps: 6,453 mfu: 0.20% [rank7]:2024-08-06 03:22:02,532 - root - INFO - step: 18 loss: 6.6874 memory: 0.51GiB(0.64%) wps: 12,695 mfu: 0.39% [rank7]:2024-08-06 03:22:02,863 - root - INFO - step: 19 loss: 6.6566 memory: 0.51GiB(0.64%) wps: 12,406 mfu: 0.38% [rank7]:2024-08-06 03:22:03,257 - root - INFO - step: 20 loss: 6.6629 memory: 0.51GiB(0.64%) wps: 10,392 mfu: 0.32% ``` - with strided sharding ``` [rank7]:2024-08-06 03:26:18,288 - root - INFO - step: 1 loss: 8.2069 memory: 0.50GiB(0.63%) wps: 915 mfu: 0.03% [rank7]:2024-08-06 03:26:19,084 - root - INFO - step: 2 loss: 8.1913 memory: 0.51GiB(0.64%) wps: 5,144 mfu: 0.16% [rank7]:2024-08-06 03:26:19,365 - root - INFO - step: 3 loss: 8.1148 memory: 0.51GiB(0.64%) wps: 14,593 mfu: 0.44% [rank7]:2024-08-06 03:26:19,698 - root - INFO - step: 4 loss: 7.9982 memory: 0.51GiB(0.64%) wps: 12,328 mfu: 0.37% [rank7]:2024-08-06 03:26:20,011 - root - INFO - step: 5 loss: 7.8382 memory: 0.51GiB(0.64%) wps: 13,100 mfu: 0.40% [rank7]:2024-08-06 03:26:20,498 - root - INFO - step: 6 loss: 7.6293 memory: 0.51GiB(0.64%) wps: 8,423 mfu: 0.26% [rank7]:2024-08-06 03:26:21,126 - root - INFO - step: 7 loss: 7.4454 memory: 0.51GiB(0.64%) wps: 6,530 mfu: 0.20% [rank7]:2024-08-06 03:26:21,472 - root - INFO - step: 8 loss: 7.3337 memory: 0.51GiB(0.64%) wps: 11,843 mfu: 0.36% [rank7]:2024-08-06 03:26:21,849 - root - INFO - step: 9 loss: 7.1960 memory: 0.51GiB(0.64%) wps: 10,892 mfu: 0.33% [rank7]:2024-08-06 03:26:22,229 - root - INFO - step: 10 loss: 7.1208 memory: 0.51GiB(0.64%) wps: 10,798 mfu: 0.33% >>>>>>>>>>>>>>>>>Checkpoint save & load<<<<<<<<<<<<<<<<<<< [rank7]:2024-08-06 03:26:50,306 - root - INFO - step: 11 loss: 7.1222 memory: 0.51GiB(0.64%) wps: 866 mfu: 0.03% [rank7]:2024-08-06 03:26:50,632 - root - INFO - step: 12 loss: 7.1189 memory: 0.51GiB(0.64%) wps: 12,589 mfu: 0.38% [rank7]:2024-08-06 03:26:50,917 - root - INFO - step: 13 loss: 6.9646 memory: 0.51GiB(0.64%) wps: 14,417 mfu: 0.44% [rank7]:2024-08-06 03:26:51,217 - root - INFO - step: 14 loss: 6.9626 memory: 0.51GiB(0.64%) wps: 13,680 mfu: 0.42% [rank7]:2024-08-06 03:26:51,514 - root - INFO - step: 15 loss: 6.8694 memory: 0.51GiB(0.64%) wps: 13,799 mfu: 0.42% [rank7]:2024-08-06 03:26:52,207 - root - INFO - step: 16 loss: 6.7994 memory: 0.51GiB(0.64%) wps: 5,910 mfu: 0.18% [rank7]:2024-08-06 03:26:53,053 - root - INFO - step: 17 loss: 6.7634 memory: 0.51GiB(0.64%) wps: 4,847 mfu: 0.15% [rank7]:2024-08-06 03:26:53,370 - root - INFO - step: 18 loss: 6.7233 memory: 0.51GiB(0.64%) wps: 12,915 mfu: 0.39% [rank7]:2024-08-06 03:26:53,686 - root - INFO - step: 19 loss: 6.7054 memory: 0.51GiB(0.64%) wps: 12,995 mfu: 0.39% [rank7]:2024-08-06 03:26:54,059 - root - INFO - step: 20 loss: 6.7130 memory: 0.51GiB(0.64%) wps: 10,991 mfu: 0.33% ```
Stack from ghstack (oldest at bottom):
Fixes issue #129229 #129206
Summary
FSDPchoose_StridedShardplacement for FSDP+TP shardingfull_tensor()resulttest
pytest test/distributed/_composable/fsdp/pytest test/distributed/_composable/test_composability/test_2d_composability.pypytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.pycc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @zhaojuanmao @mrshenli @rohan-varma @chauhang @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn
Differential Revision: D60606114