
Support eager mode for multi-process training#7327

Merged
JackCaoG merged 4 commits into master from JackCaoG/multi_process_eager_2
Jun 24, 2024
Conversation

@JackCaoG
Collaborator

In-place all_reduce is used in optimizer_step for data-parallel, multi-process training. The HLO for

  ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)
  ordinal_tensor_2 = torch.tensor([index], dtype=torch.int32).to(device)

  xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor_1, ordinal_tensor_2])

looks like

ENTRY %IrToHlo.27 (p0.1: f32[], p1.2: s32[1], p2.3: f32[1]) -> (f32[1], s32[1]) {
.......
  %all-reduce.12 = (s32[1]{0}, s32[]) all-reduce(s32[1]{0} %get-tuple-element.6, s32[] %get-tuple-element.7), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.8, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py" source_line=501}
.....
  %get-tuple-element.24 = f32[1]{0} get-tuple-element((f32[1]{0}, f32[]) %all-reduce.23), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py" source_line=501}
  %get-tuple-element.13 = s32[1]{0} get-tuple-element((s32[1]{0}, s32[]) %all-reduce.12), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/workspaces/dk3/pytorch/xla/torch_xla/core/xla_model.py" source_line=501}
  ROOT %tuple.26 = (f32[1]{0}, s32[1]{0}) tuple(f32[1]{0} %get-tuple-element.24, s32[1]{0} %get-tuple-element.13)
}

Note that in the above HLO we have 2 outputs but only one all_reduce. Without this change we would eagerly evaluate each output, which results in the all_reduce being compiled and executed twice, which is not ideal. For ops like all_reduce where a single op has multiple outputs, it is better to group the execution and execute only once.
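To illustrate the cost being avoided, here is a minimal, self-contained sketch (plain Python, not the actual torch_xla implementation): a stand-in `all_reduce_sum` models one compile/execute of a multi-output op, and a counter shows that fetching each output separately runs the op twice while grouped evaluation runs it once.

```python
# Hypothetical sketch of the grouping argument; `all_reduce_sum` and the
# counter are illustrative stand-ins, not real torch_xla APIs.

execution_count = 0

def all_reduce_sum(tensors):
    """Pretend multi-output collective over a list of 'tensors'."""
    global execution_count
    execution_count += 1              # models one compile/execute cycle
    return [sum(t) for t in tensors]  # placeholder math, not real HLO

a, b = [1.0, 2.0], [3, 4]

# Naive eager evaluation: one execution per output -> the op runs twice.
out_a = all_reduce_sum([a, b])[0]
out_b = all_reduce_sum([a, b])[1]
naive_runs = execution_count

# Grouped evaluation: execute once, then read both outputs.
execution_count = 0
out_a, out_b = all_reduce_sum([a, b])
grouped_runs = execution_count

print(naive_runs, grouped_runs)  # 2 1
```

The same idea applies to the HLO above: both get-tuple-element results come from one %all-reduce, so the eager runtime should materialize them in a single execution.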

@JackCaoG JackCaoG added the tpuci label Jun 21, 2024
@JackCaoG JackCaoG added the eager and usability labels Jun 21, 2024
@JackCaoG JackCaoG marked this pull request as ready for review June 21, 2024 23:29
@JackCaoG
Collaborator Author

This PR is ready for review.

@JackCaoG JackCaoG merged commit 222bbd8 into master Jun 24, 2024
@JackCaoG JackCaoG deleted the JackCaoG/multi_process_eager_2 branch June 24, 2024 23:57