
[LoweringContext] Support an optimized parameter mapping for SPMD #8460

Merged
tengyifei merged 1 commit into pytorch:master from rpsilva-aws:rpsilva_lc_mapping_v3 on Dec 7, 2024

Conversation

@rpsilva-aws
Collaborator

Currently, the existing parameter mapping for the lowering context is not well suited for SPMD. For large models, it creates a large synchronous bottleneck when transferring all device data to the host. This is caused by the ReplicateShardedData computation, which gathers and reassembles each parameter's sharded data across multiple devices. This is by design, since the mapping is expected to collect all parameters regardless of their allocation.

In this PR, we introduce a new mapping that does not invoke the sharded replication, but instead uses references to the device data. This is generally sufficient and preferred in most cases, where the user only wants to access the valid parameters (those for which tensor_parameter_id does not return -1, i.e. parameters that are not 'fake').
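
For reference, a minimal sketch of how the two mappings might be exercised on an SPMD-sharded input (the exact `LoweringContext` constructor and module paths are assumptions based on current torch_xla APIs, not code from this PR):

```python
# Hedged sketch: compare the existing host-copy mapping with the new
# device-reference mapping on an SPMD-sharded tensor.
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

device = xm.xla_device()
w = torch.randn(8, 8, device=device)
x = torch.randn(8, 8, device=device)
xs.mark_sharding(w, mesh, ('data', None))

y = w @ x

ctx = torch_xla._XLAC.lowering.LoweringContext("Example")
ctx.build([y])

# Existing mapping: triggers ReplicateShardedData for `w`, gathering its
# shards from every device and copying the result back to the host.
host_mapping = ctx.parameter_id_tensor_mapping()

# New mapping (this PR): returns references to the device-backed tensors,
# so no cross-device gather or device-to-host transfer is needed.
device_mapping = ctx.device_parameter_id_tensor_mapping()
```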

@rpsilva-aws
Collaborator Author

Re-opened from #8453, cleaned up the merge commit.

@tengyifei tengyifei self-requested a review December 5, 2024 23:51
@tengyifei tengyifei added the tpuci label Dec 5, 2024
@tengyifei tengyifei marked this pull request as ready for review December 5, 2024 23:52
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_lc_mapping_v3 branch from 8fd7ac7 to 9858577 on December 5, 2024 23:53
@tengyifei tengyifei merged commit 5d11f66 into pytorch:master Dec 7, 2024
@rpsilva-aws rpsilva-aws deleted the rpsilva_lc_mapping_v3 branch December 9, 2024 19:03
tengyifei added a commit that referenced this pull request Jan 2, 2025
Previously, scan used `parameter_id_tensor_mapping` to fetch tensors
hoisted to HLO parameters, e.g. the fn being scanned may create
additional tensors while it is running. `parameter_id_tensor_mapping`
fetches those tensors back to the host as XLA literals and creates new
tensors wrapping them, resulting in additional host RAM usage.

PR #8460 added `device_parameter_id_tensor_mapping`, which returns the
actual device-backed tensors instead of another copy. So we'll use that
and test that this avoids host transfers.
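
A rough, self-contained sketch of the kind of check that change could make (illustrative only; names and the assertion below are assumptions, not the actual test code):

```python
# Hedged sketch: check that tensors returned by the device mapping still
# live on the XLA device, i.e. no host-side copy was materialized.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)
out = torch.sin(a) + b

ctx = torch_xla._XLAC.lowering.LoweringContext("ScanFn")
ctx.build([out])

for param_id, tensor in ctx.device_parameter_id_tensor_mapping().items():
    # Device-backed references should report an XLA device; tensors rebuilt
    # from host literals (the old path) would have required a transfer first.
    assert tensor.device.type == 'xla'
```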