Skip to content

[LoweringContext] SPMD propagation#8471

Merged
tengyifei merged 5 commits intopytorch:masterfrom
rpsilva-aws:rpsilva_spmd_lc_v2
Dec 10, 2024
Merged

[LoweringContext] SPMD propagation#8471
tengyifei merged 5 commits intopytorch:masterfrom
rpsilva-aws:rpsilva_spmd_lc_v2

Conversation

@rpsilva-aws
Copy link
Copy Markdown
Collaborator

Introduce SPMD sharding to the lowering context: this ensures that the computation has the respective sharding specs deduced from the inputs (scoped to the creation of the parameters), and to propagate the input shardings to the output.

HloModule SomeFn.12, entry_computation_layout={(f32[2048]{0}, f32[], f32[32,2048]{1,0})->(f32[2048]{0}, f32[32,2048]{1,0})}

ENTRY %SomeFn.12 (p0.3: f32[2048], p1.7: f32[], p2.8: f32[32,2048]) -> (f32[2048], f32[32,2048]) {
  %p0.3 = f32[2048]{0} parameter(0), sharding={devices=[4,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31 last_tile_dim_replicate}
  %constant.2 = f32[] constant(1)
  %constant.1 = f32[] constant(1)
  %multiply.4 = f32[] multiply(f32[] %constant.2, f32[] %constant.1)
  %broadcast.5 = f32[2048]{0} broadcast(f32[] %multiply.4), dimensions={}
  %add.6 = f32[2048]{0} add(f32[2048]{0} %p0.3, f32[2048]{0} %broadcast.5)
  %p2.8 = f32[32,2048]{1,0} parameter(2), sharding={devices=[1,8,4]0,8,16,24,1,9,17,25,2,10,18,26,3,11,19,27,4,12,20,28,5,13,21,29,6,14,22,30,7,15,23,31 last_tile_dim_replicate}
  %p1.7 = f32[] parameter(1), sharding={replicated}
  %broadcast.9 = f32[32,2048]{1,0} broadcast(f32[] %p1.7), dimensions={}
  %multiply.10 = f32[32,2048]{1,0} multiply(f32[32,2048]{1,0} %p2.8, f32[32,2048]{1,0} %broadcast.9)
  ROOT %tuple.11 = (f32[2048]{0}, f32[32,2048]{1,0}) tuple(f32[2048]{0} %add.6, f32[32,2048]{1,0} %multiply.10)
}

@rpsilva-aws rpsilva-aws marked this pull request as ready for review December 9, 2024 18:46
@rpsilva-aws
Copy link
Copy Markdown
Collaborator Author

FYI: @tengyifei

@tengyifei tengyifei added the tpuci label Dec 9, 2024
Comment thread test/spmd/test_spmd_lowering_context.py Outdated
Comment thread test/run_tests.sh Outdated
Comment thread torch_xla/csrc/init_python_bindings.cpp
@tengyifei tengyifei merged commit b068cab into pytorch:master Dec 10, 2024
@rpsilva-aws rpsilva-aws deleted the rpsilva_spmd_lc_v2 branch December 10, 2024 00:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants