[functional_collective] remove the logic that forces torch-xla to use legacy funcol#123776

Closed
yifuwang wants to merge 3 commits into gh/yifuwang/77/base from gh/yifuwang/77/head

Conversation

@yifuwang
Collaborator

@yifuwang yifuwang commented Apr 10, 2024

… legacy funcol

After pytorch/xla#6887, torch-xla now also uses
the all_reduce from native funcol. So we can remove this logic.

[ghstack-poisoned]
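As a rough illustration of the kind of gating this PR removes (all names below are hypothetical, not PyTorch's actual internals):

```python
# Hypothetical sketch of the gating removed by this PR; names are
# illustrative. Before pytorch/xla#6887, tracing under torch-xla had to be
# routed to the legacy funcol all_reduce; now both paths can use native funcol.

def pick_all_reduce(tracing_with_xla, native_all_reduce, legacy_all_reduce):
    # Old behavior: force legacy funcol whenever torch-xla was tracing.
    if tracing_with_xla:
        return legacy_all_reduce
    return native_all_reduce

def pick_all_reduce_after(native_all_reduce):
    # After this PR: torch-xla also consumes the native funcol all_reduce,
    # so no special case is needed.
    return native_all_reduce
```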
@pytorch-bot

pytorch-bot Bot commented Apr 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123776

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 2be8459 with merge base 585cd11:

FLAKY - The following job failed, but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the `oncall: distributed` label Apr 10, 2024
@yifuwang yifuwang marked this pull request as ready for review April 10, 2024 22:13
@yifuwang yifuwang requested review from wanchaol and wconstab April 10, 2024 22:13
…-xla to use legacy funcol"
@yifuwang yifuwang requested a review from yf225 April 11, 2024 21:15
Collaborator

@wanchaol wanchaol left a comment


Nice! I wonder if we should update the xla pin to include the xla side of changes? cc @alanwaketan

@alanwaketan
Collaborator

> Nice! I wonder if we should update the xla pin to include the xla side of changes? cc @alanwaketan

Let's do it. Just grep xla.txt and use the hash of the head of torch-xla.
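The suggested pin update might look like the following shell sketch. The pin file path (`.github/ci_commit_pins/xla.txt`) and the helper function are assumptions based on the "grep xla.txt" hint, not the exact workflow:

```shell
# Hedged sketch of updating the torch-xla CI pin; the pin file path is an
# assumption. In practice the hash would come from the head of torch-xla,
# e.g. via: git ls-remote https://github.com/pytorch/xla.git HEAD | cut -f1
update_xla_pin() {
    pin_file="$1"   # e.g. .github/ci_commit_pins/xla.txt
    new_hash="$2"   # commit hash of the torch-xla head
    echo "$new_hash" > "$pin_file"
    echo "pinned torch-xla to $new_hash"
}
```

Usage would then be something like `update_xla_pin .github/ci_commit_pins/xla.txt <hash>` followed by committing the changed pin file.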

…-xla to use legacy funcol"
@yifuwang yifuwang added the `topic: not user facing` label Apr 13, 2024
pytorchmergebot pushed a commit that referenced this pull request Apr 13, 2024
… funcol ops (#123777)

## Summary

After this PR, the functional collective Python APIs will stop honoring `TORCH_DISABLE_NATIVE_FUNCOL` and only use native funcol ops. Specifically, this PR:
- Removed `use_native_funcol()`.
- Removed the code path in the Python APIs when `use_native_funcol()` is `False`.
- Changed the CI tests that run on both native funcol and legacy funcol through the Python API to run only with native funcol.
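The dispatch change described above can be sketched as follows (hypothetical simplified code; the function and parameter names are illustrative, not PyTorch's actual implementation):

```python
import os

# Hypothetical before/after sketch of the funcol Python API dispatch.

def all_reduce_before(tensor, reduce_op, group, native_impl, legacy_impl):
    # Before this PR: each call consulted TORCH_DISABLE_NATIVE_FUNCOL
    # (via use_native_funcol()) and could fall back to legacy funcol.
    if os.environ.get("TORCH_DISABLE_NATIVE_FUNCOL") == "1":
        return legacy_impl(tensor, reduce_op, group)
    return native_impl(tensor, reduce_op, group)

def all_reduce_after(tensor, reduce_op, group, native_impl):
    # After this PR: the env var is ignored and only native funcol ops run.
    return native_impl(tensor, reduce_op, group)
```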

## Test Changes

`test_functional_api.py`
- Removed the tests where only one of output_split_sizes or input_split_sizes is specified. This behavior is unreliable and has been removed from the native funcol.
- Removed `TestWaitiness` which tests an implementation detail of the legacy funcol. We have equivalent tests for native funcol in `test/distributed/test_c10d_functional_native.py` https://github.com/pytorch/pytorch/blob/b7fac76fc259394136bc77b3e39d5705919e5c4c/test/distributed/test_c10d_functional_native.py#L114-L116

`test/distributed/_tensor/test_dtensor.py`
`test/distributed/_tensor/test_dtensor_compile.py`
`test/distributed/test_device_mesh.py`
`test/distributed/_tensor/experimental/test_tp_transform.py`
`test/distributed/_tensor/test_matrix_ops.py`
`test/distributed/test_inductor_collectives.py`
- All these tests previously ran twice, once with native funcol and once with legacy funcol. Changed them to run only with native funcol.

`test/distributed/test_c10d_functional_native.py`
- Removed the `run_with_native_funcol` decorators.

Pull Request resolved: #123777
Approved by: https://github.com/wanchaol
ghstack dependencies: #123776
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
@github-actions github-actions Bot deleted the gh/yifuwang/77/head branch May 14, 2024 01:52

Labels

Merged · oncall: distributed · topic: not user facing


4 participants