
[SPMD] Introduce high level manual sharding APIs#6931

Merged
alanwaketan merged 5 commits into master from alanwaketan/manual_sharding_api
Apr 17, 2024

Conversation

@alanwaketan
Collaborator

Summary:
This pull request introduces:

  1. enable_manual_sharding: which starts the manual sharding region.
  2. disable_manual_sharding: which ends the manual sharding region.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_api_e2e
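The pairing of the two APIs can be illustrated with a plain-Python simulation (the function names mirror the PR's APIs, but the list-based "global tensor" and per-device slicing here are simplified stand-ins for illustration, not the real torch_xla implementation):

```python
# Plain-Python sketch of the manual-sharding region semantics introduced
# by this PR. In real torch_xla these operate on XLA tensors over a device
# mesh; here a "global tensor" is a flat list and a shard is a slice of it.

def enable_manual_sharding(global_tensor, num_devices, device_index):
    """Enter the manual region: hand back this device's local shard."""
    shard_len = len(global_tensor) // num_devices
    start = device_index * shard_len
    return global_tensor[start:start + shard_len]

def disable_manual_sharding(local_shards):
    """Leave the manual region: stitch per-device shards back together."""
    global_tensor = []
    for shard in local_shards:
        global_tensor.extend(shard)
    return global_tensor

if __name__ == "__main__":
    data = list(range(8))
    num_devices = 4
    # Inside the region, each "device" operates on its own shard manually;
    # no automatic sharding propagation happens on its behalf.
    shards = [enable_manual_sharding(data, num_devices, i)
              for i in range(num_devices)]
    shards = [[x * 10 for x in shard] for shard in shards]  # manual compute
    result = disable_manual_sharding(shards)
    print(result)  # [0, 10, 20, 30, 40, 50, 60, 70]
```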

@alanwaketan alanwaketan requested review from jonb377 and yeounoh April 17, 2024 00:46
@alanwaketan alanwaketan self-assigned this Apr 17, 2024
*,
mesh: Mesh = None) -> XLAShardedTensor:
"""
This API enables manual sharding for the given tensor. Manual sharding disables auto sharding propagation and auto
Contributor

"auto" --> "SPMD", think it's important to not confuse.

Contributor

@yeounoh yeounoh left a comment

LGTM, left a comment for comment :)

Collaborator

@jonb377 jonb377 left a comment

LGTM!

"""
mesh = get_global_mesh() if mesh is None else mesh
t = mark_sharding(unwrap_sharded_tensor(t), mesh, partition_spec)
t = torch_xla._XLAC._spmd_full_to_shard_shape(unwrap_sharded_tensor(t))
Collaborator

Can t here be DeviceData?

Collaborator Author

You mean the input? Yes!
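The snippet under discussion calls `_spmd_full_to_shard_shape`, which converts the global (full) shape into the per-device (shard) shape. Assuming an evenly divisible sharding, the shape arithmetic amounts to dividing each sharded dimension by the mesh size along its axis; a sketch with a hypothetical helper (not the torch_xla C++ op):

```python
def full_to_shard_shape(full_shape, mesh_shape, partition_spec):
    """Per-device shard shape for an evenly divisible sharding.

    partition_spec[i] names the mesh axis (an index into mesh_shape)
    that dimension i is sharded over, or None for replicated dims.
    Illustrative helper only; not the real _spmd_full_to_shard_shape.
    """
    shard_shape = []
    for dim, axis in zip(full_shape, partition_spec):
        if axis is None:
            shard_shape.append(dim)  # replicated: keep the full size
        else:
            assert dim % mesh_shape[axis] == 0, "uneven sharding"
            shard_shape.append(dim // mesh_shape[axis])
    return tuple(shard_shape)

# An (8, 128) tensor sharded along dim 0 over a 4-device mesh:
print(full_to_shard_shape((8, 128), (4,), (0, None)))  # (2, 128)
```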

"""
This API enables manual sharding for the given tensor. Manual sharding disables auto sharding propagation and auto
partitioning for the given tensor and all subsequent tensors produced by ops that use the given tensor as
input, and therefore allows the user to manually call collectives for the tensor and subsequent tensors. It
Collaborator

Also just curious - how will we enable collectives in a manual region?

Collaborator Author

XLA cc ops should work by default; just use them as normal. However, we need to teach our cc ops wrappers to be aware of SPMD mode, so that will be phase 2 of the manual sharding work.
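The kind of collective a user would issue by hand inside a manual region can be sketched in plain Python (an all-reduce simulated over per-device shards; real code would call torch_xla's cc ops on XLA tensors):

```python
def manual_all_reduce(per_device_shards):
    """Simulated all-reduce: every device ends up holding the elementwise
    sum of all devices' shards. Inside a manual sharding region the user
    issues this collective explicitly, instead of relying on SPMD
    propagation to insert it automatically."""
    summed = [sum(vals) for vals in zip(*per_device_shards)]
    return [list(summed) for _ in per_device_shards]

shards = [[1, 2], [3, 4], [5, 6], [7, 8]]  # one shard per "device"
print(manual_all_reduce(shards)[0])  # [16, 20]
```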

@alanwaketan alanwaketan merged commit 9b2ac4b into master Apr 17, 2024
@alanwaketan alanwaketan deleted the alanwaketan/manual_sharding_api branch April 17, 2024 18:28
lausannel pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024
baoleai pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024
3 participants