
[SPMD] auto-sharding PoC#6719

Merged
yeounoh merged 39 commits intomasterfrom
spmd_auto_alpa
Mar 14, 2024
Conversation

@yeounoh
Contributor

@yeounoh yeounoh commented Mar 12, 2024

This implements a PoC prototype on XLA:TPU, as described in #6322

PyTorch/XLA auto-sharding can be enabled by one of the following:

  • Setting envvar XLA_SPMD_AUTO=1
  • Calling the SPMD API at the beginning of your code:
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
  • Calling torch.distributed._tensor.distribute_module with an auto-policy on xla:
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, the model should be loaded onto the xla device via distribute_module.
model = MyModule()  # nn.module
sharded_model = distribute_module(model, device_mesh, auto_policy)
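For the first option, a minimal sketch of a training-script preamble (assumption: XLA_SPMD_AUTO is read when torch_xla initializes, so it would need to be set before torch_xla is imported; the actual torch_xla import is omitted so the sketch runs without a TPU):

```python
import os

# Assumption: the env var must be set before torch_xla is imported,
# since the runtime reads it at initialization.
os.environ["XLA_SPMD_AUTO"] = "1"

# import torch_xla  # would follow here in a real script

print(os.environ["XLA_SPMD_AUTO"])  # → 1
```

Equivalently, the variable can be set on the command line, e.g. `XLA_SPMD_AUTO=1 python train.py`.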

Some notable limitations that we will address in follow-ups:

  • XLA:GPU is not supported
  • TPU pod is not supported

cc @baoleai

@yeounoh yeounoh added the distributed SPMD and other distributed things. label Mar 12, 2024
@yeounoh yeounoh requested a review from JackCaoG March 12, 2024 00:21
@yeounoh yeounoh self-assigned this Mar 12, 2024
@yeounoh yeounoh marked this pull request as draft March 12, 2024 00:22
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 126ceee to 4d568ef on March 12, 2024 00:25
Comment thread WORKSPACE Outdated
Comment thread setup.py Outdated
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 2 times, most recently from 6ca8f97 to d6dc442 on March 12, 2024 00:38
Comment thread test/spmd/test_dynamo_spmd.py
Comment thread test/spmd/test_spmd_graph_dump.py
Comment thread torch_xla/csrc/init_python_bindings.cpp Outdated
Comment thread torch_xla/csrc/runtime/profiler.cc
Comment thread torch_xla/csrc/init_python_bindings.cpp
@yeounoh yeounoh force-pushed the spmd_auto_alpa branch 12 times, most recently from 303b239 to d3c1d70 on March 12, 2024 07:34
yeounoh added 28 commits March 14, 2024 00:49
* Assume REPLICATED for UNKNOWN during parameter resharding
…patch

* Ungroup resharding ops

* Replace device data after resharding
Delete quantization openxla patch

Debugging probes
* Disable parameter wrapping with auto-sharding
Comment thread test/run_tests.sh
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/spmd/test_dtensor_integration.py"
run_test "$CDIR/spmd/test_dtensor_integration2.py"
Collaborator

do we need this on TPU CI as well or it is ok to leave out?

Contributor Author

Ohhh I think it's ok to leave out. Want to run this sanity check on TPU!

Collaborator

@JackCaoG JackCaoG left a comment

Feel free to address the remaining comments in a follow-up.

Labels

backport_2.3 distributed SPMD and other distributed things.


2 participants