Enable eager spmd #7341

Merged
JackCaoG merged 4 commits into master from JackCaoG/eager_spmd
Jun 26, 2024
Conversation

@JackCaoG (Collaborator)
No description provided.

@JackCaoG added the labels "eager: PyTorch/XLA eager-mode" and "distributed: SPMD and other distributed things." on Jun 24, 2024
@JackCaoG force-pushed the JackCaoG/eager_spmd branch from 3ec9f92 to 4244dcb on June 25, 2024 22:55
@JackCaoG marked this pull request as ready for review June 26, 2024 18:30
@alanwaketan (Collaborator)

How could SPMD possibly work for eager mode?

@JackCaoG (Collaborator, Author)

> How could SPMD possibly work for eager mode?

Consider eager mode as calling mark_step after every PyTorch op.
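To make that analogy concrete, here is a minimal pure-Python sketch (hypothetical names, not the actual torch_xla implementation): in the default lazy mode, traced ops accumulate until mark_step compiles them as one graph, while eager mode flushes after every single op.

```python
class LazyTensorRuntime:
    """Toy model of a lazy-tensor runtime with an optional eager mode."""

    def __init__(self, eager_mode=False):
        self.eager_mode = eager_mode
        self.pending_ops = []      # ops traced but not yet executed
        self.executed_graphs = []  # each entry: the ops of one compiled graph

    def trace_op(self, name):
        self.pending_ops.append(name)
        if self.eager_mode:
            # Eager mode behaves like calling mark_step after every op.
            self.mark_step()

    def mark_step(self):
        # Compile and execute everything traced so far as one graph.
        if self.pending_ops:
            self.executed_graphs.append(tuple(self.pending_ops))
            self.pending_ops = []

# Lazy (default) mode: one graph containing all three ops.
lazy = LazyTensorRuntime(eager_mode=False)
for op in ("cos", "add", "mul"):
    lazy.trace_op(op)
lazy.mark_step()
assert lazy.executed_graphs == [("cos", "add", "mul")]

# Eager mode: three single-op graphs, executed immediately.
eager = LazyTensorRuntime(eager_mode=True)
for op in ("cos", "add", "mul"):
    eager.trace_op(op)
assert eager.executed_graphs == [("cos",), ("add",), ("mul",)]
```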

@alanwaketan (Collaborator)

>> How could SPMD possibly work for eager mode?
>
> consider eager mode as calling mark_step after every pytorch op.

Then how do sharding propagation and auto-partitioning work? I assume they don't carry state over from the last graph?

@JackCaoG (Collaborator, Author)

>>> How could SPMD possibly work for eager mode?
>>
>> consider eager mode as calling mark_step after every pytorch op.
>
> Then how sharding propagation and auto partition work? I assume they don't carry states from last graph?

Sharding propagation and auto-partitioning still happen within each subgraph we compile. For example:

t3 = t1.cos()
t3 += t2

We will compile a graph for the cos, which calculates the output sharding for t3 and assigns t3 a sharded PJRT buffer that is not yet ready. We then proceed with another graph for the add; at that point we already know the input sharding for t3, so we just propagate it to the output.
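The two-graph flow above can be sketched in plain Python (hypothetical names such as ShardedTensor and compile_and_launch are illustrative, not torch_xla APIs): each one-op subgraph computes an output sharding from its inputs, so sharding information flows between graphs through the tensors themselves rather than through any state carried over from the previous graph.

```python
class ShardedTensor:
    """Toy stand-in for an XLA tensor backed by a sharded PJRT buffer."""

    def __init__(self, name, sharding=None):
        self.name = name
        self.sharding = sharding  # e.g. a partition spec like ("x", None)
        self.ready = False        # buffer can be assigned before data exists

def compile_and_launch(op, inputs, output_name):
    # The SPMD partitioner runs on this one-op subgraph. As a stand-in
    # for real sharding propagation, forward the first input's sharding
    # to the output and hand back a not-yet-ready sharded buffer.
    out = ShardedTensor(output_name, sharding=inputs[0].sharding)
    out.ready = False
    return out

t1 = ShardedTensor("t1", sharding=("x", None))
t2 = ShardedTensor("t2", sharding=("x", None))

# Graph 1: cos. The output sharding of t3 is decided here, and t3 gets
# a sharded buffer even though the computation may not have finished.
t3 = compile_and_launch("cos", [t1], "t3")
assert t3.sharding == ("x", None)

# Graph 2: add. t3's input sharding is already known from graph 1 and
# simply propagates to the new output.
t4 = compile_and_launch("add", [t3, t2], "t4")
assert t4.sharding == ("x", None)
```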

@alanwaketan (Collaborator)

Okay, that's fair.

  XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
- if (sharding && sharding->sharding.type() != xla::OpSharding::UNKNOWN) {
+ // don't propagate sharding in eager mode.
+ if (!XLAGraphExecutor::Get()->UseEagerMode() && sharding &&
Collaborator
May I ask why?

Collaborator Author

It complained that the output tensor already has a sharding and we can't propagate to it. This happens in the backward pass. I didn't spend enough time debugging it, but I don't expect users to actually run eager mode with the step fn (forward and backward); I only expect them to run it for some data preprocessing on device, so I just quickly unblocked myself.
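The guard in the diff above boils down to one extra condition. Here is a pure-Python model of that predicate (not the real C++, and should_propagate_sharding is a name invented for this sketch): in eager mode, propagation is skipped even when the input carries a known sharding.

```python
# Stand-in for xla::OpSharding::UNKNOWN from the diff above.
UNKNOWN = "UNKNOWN"

def should_propagate_sharding(sharding, eager_mode):
    # Mirrors: if (!XLAGraphExecutor::Get()->UseEagerMode() && sharding &&
    #              sharding->sharding.type() != xla::OpSharding::UNKNOWN)
    return (not eager_mode) and sharding is not None and sharding != UNKNOWN

# Lazy mode with a known sharding: propagate.
assert should_propagate_sharding("tiled", eager_mode=False) is True
# Eager mode: never propagate, avoiding the "output already has a
# sharding" complaint seen in the backward pass.
assert should_propagate_sharding("tiled", eager_mode=True) is False
# No sharding, or an UNKNOWN one: nothing to propagate either way.
assert should_propagate_sharding(None, eager_mode=False) is False
assert should_propagate_sharding(UNKNOWN, eager_mode=False) is False
```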

@alanwaketan (Collaborator) left a comment

LGTM. Thanks, Jack!

@JackCaoG JackCaoG merged commit d5e5713 into master Jun 26, 2024
JackCaoG added a commit that referenced this pull request Jul 12, 2024
bhavya01 pushed a commit that referenced this pull request Jul 15, 2024
