
Optimize execution for ops that have multiple outputs in eager mode #7680

Merged
JackCaoG merged 2 commits into master from JackCaoG/eager_faster
Jul 16, 2024

Conversation

@JackCaoG
Collaborator

In eager mode, execution happens when we create an XLATensor with an IR; we use that IR as the root to build and execute the graph.

This is mostly fine, but for ops that have multiple outputs (like native_batch_norm), the outputs share a good amount of common HLOs, so it is much faster to execute all of them in a single graph. Eager mode in PyTorch/XLA can't really execute HLOs one by one, so the goal is to execute once (ideally) per PyTorch op.
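To make the shared-HLO point concrete, here is a toy, pure-Python sketch (not the actual PyTorch/XLA lowering) of a batch-norm-like op: its three outputs are all rooted in the same mean/variance subcomputation, which is why lowering them into one graph avoids redundant work.

```python
# Toy sketch (plain Python, not the actual HLO lowering): a batch-norm-like
# op with three outputs that share the mean/variance computation.
def batch_norm_1d(xs, eps=1e-5):
    n = len(xs)
    mean = sum(xs) / n                            # shared intermediate
    var = sum((x - mean) ** 2 for x in xs) / n    # shared intermediate
    inv_std = (var + eps) ** -0.5
    normalized = [(x - mean) * inv_std for x in xs]
    # Three outputs, all built on the same mean/var subgraph; executing
    # them one by one would recompute mean and var for each output.
    return normalized, mean, inv_std

out, mean, inv_std = batch_norm_1d([1.0, 2.0, 3.0, 4.0])
print(mean)  # 2.5
```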

The change in this PR will:

  1. delay the eager execution for some ops when they create new XLATensors with IRs, and
  2. execute the HLO for all of those XLATensors after they are created.

I will take another pass to make sure I didn't mess anything up, but I would appreciate it if someone could look closely at my change inside tensor_methods.cpp.
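The two steps can be sketched with a toy lazy-tensor model (all names here are hypothetical, not the actual torch_xla internals): tensor creation only records the IR, and a single execution is issued once all of an op's outputs exist.

```python
# Toy model of the delayed-execution pattern (hypothetical names, not the
# torch_xla implementation).
executions = []  # records each graph execution and which IRs it covered

class LazyTensor:
    def __init__(self, ir):
        # Step 1: creating a tensor with an IR no longer triggers execution.
        self.ir = ir

def execute_graph(tensors):
    # One graph execution whose roots are all of the outputs' IRs.
    executions.append([t.ir for t in tensors])

def multi_output_op():
    # Build every output tensor first...
    outs = [LazyTensor(f"out{i}") for i in range(3)]
    # ...then (step 2) execute once for the whole op.
    execute_graph(outs)
    return outs

multi_output_op()
print(len(executions))  # 1 execution instead of 3
```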

@JackCaoG JackCaoG added the eager PyTorch/XLA eager-mode label Jul 12, 2024
@JackCaoG
Collaborator Author

I also intentionally didn't handle the collectives. Collectives return an all_reduce token, which we actually don't want to execute in the eager case. I will handle that in a separate PR.

@aws-rhsoln
Contributor

Curious how much perf boost do we expect when we fuse them into a single graph?

@JackCaoG
Collaborator Author

JackCaoG commented Jul 15, 2024

Curious how much perf boost do we expect when we fuse them into a single graph?

For the following test code:

import time

import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)

device = torch_xla.device()
m = nn.BatchNorm2d(16).to(device)
m.train()
input = torch.randn(16, 16, 1024, 1024, device=device)

start = time.time()
for _ in range(20):
  input = m(input)
xm.wait_device_ops()
end = time.time()
duration = end - start
print(f"total time = {duration}")

With my change, total time = 0.46190381050109863; without it, total time = 14.28174352645874. I actually don't know why it is ~30x faster, but I did verify that in the HLO without my change, BatchNorm2d computes the results one by one.

@JackCaoG JackCaoG marked this pull request as ready for review July 15, 2024 18:37
@JackCaoG
Collaborator Author

@alanwaketan @wonjoolee95 This one is ready for review.

@JackCaoG JackCaoG merged commit b2c7f65 into master Jul 16, 2024
