[WIP] Integrate Dynamo + SPMD for Inference#4862

Closed
steventk-g wants to merge 13 commits into pytorch:master from steventk-g:steven_spmd_inference

Conversation

Collaborator

@steventk-g steventk-g commented Apr 7, 2023

Includes @yeounoh's changes to use SPMD with torch.compile

@steventk-g steventk-g changed the title Integrate Dynamo + SPMD for Inference [WIP] Integrate Dynamo + SPMD for Inference Apr 7, 2023
@steventk-g steventk-g requested a review from yeounoh April 7, 2023 21:51
Comment thread test/spmd/test_dynamo_spmd_inference_latency.py Outdated
if FLAGS.fake_data:
assert FLAGS.test_set_batch_size == 1
test_loader = xu.SampleGenerator(
data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim).to(device),
Contributor

@yeounoh yeounoh Apr 10, 2023

Maybe we should add a comment explaining that we needed .to(device) here for spmd+dynamo?

Contributor

@yeounoh yeounoh left a comment

Mostly LGTM, added a few minor comments. Also we need to address linter checks.

@steventk-g steventk-g force-pushed the steven_spmd_inference branch 2 times, most recently from fdcf5f0 to a1e0c14 Compare April 11, 2023 23:42
Comment thread test/spmd/test_inference_spmd_dynamo_imagenet.py Outdated
torch.manual_seed(42)

model = torchvision.models.resnet50().to(
device) # get_model_property('model_fn')().to(device)
Contributor

Need to clean this up, either use get_model_property('model_fn')().to(device) or stick to resnet50(), in which case we don't need some of the model props code above.

Comment thread test/dynamo/test_bridge.py Outdated

import torch_xla.core.dynamo_bridge as bridge
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
Contributor

Is this needed?

Collaborator Author

Removed

Comment thread test/dynamo/test_bridge.py Outdated
python3 pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 pytorch/xla/test/spmd/test_train_spmd_linear_model.py
XLA_USE_SPMD=1 python3 pytorch/xla/test/spmd/test_train_spmd_linear_model.py --sharding batch
XLA_USE_SPMD=1 python3 pytorch/xla/test/spmd/test_inference_spmd_dynamo_imagenet.py --fake_data --use_dynamo --sharding batch
Contributor

awesome, thank you!

Collaborator

How long will this test take? It is inference so I guess it won't be too bad, but I am trying to avoid adding too much burden to the already very long CI.

Comment thread torch_xla/core/dynamo_bridge.py Outdated


def is_xla_tensor(tensor: torch.Tensor) -> bool:
# TODO(yeounoh) check if tensor sharding annotation can be accessed here
Contributor

I think we can remove my comment here now.

Contributor

@yeounoh yeounoh left a comment

Added more comments

Comment on lines +189 to +191
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
Collaborator

why are these needed? I had these in my dynamo test because the test data is fake, so I needed a mark_step to materialize the fake input; if that's not the case here we don't need to do a mark_step.

Collaborator Author

We need to materialize the input here as well. @yeounoh may be able to add more context

Collaborator

ok, maybe add a comment above to mention that this mark_step is to materialize the fake input then.

Collaborator Author

Done
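The mark_step discussion above hinges on lazy-tensor semantics: operations on the fake input are only recorded, not executed, until a barrier flushes them. A rough plain-Python analogy of that behavior (the LazyTensor class here is illustrative only, not the torch_xla API):

```python
# Conceptual sketch (plain Python, not torch_xla): why a mark_step()-style
# barrier is needed to materialize lazily recorded tensor operations
# before timing or inspecting anything.

class LazyTensor:
    """Records operations instead of executing them immediately."""

    def __init__(self, value):
        self.value = value
        self.pending_ops = []   # deferred ops, executed only at the barrier

    def add_(self, n):
        # Lazily record the op; nothing is computed yet.
        self.pending_ops.append(lambda v: v + n)
        return self

    def materialize(self):
        # The barrier: flush all recorded ops, like xm.mark_step().
        for op in self.pending_ops:
            self.value = op(self.value)
        self.pending_ops.clear()
        return self.value

t = LazyTensor(0)
t.add_(1).add_(2)
assert t.pending_ops            # ops are still pending before the barrier
result = t.materialize()        # "mark_step": the fake input is now materialized
assert result == 3 and not t.pending_ops
```

In the real test, xm.wait_device_ops() additionally blocks until the device finishes, and met.clear_all() resets metrics so only the measured run is counted.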

Comment thread test/spmd/test_inference_spmd_dynamo_imagenet.py
&coll, std::move(arguments), placeholders, std::move(cachedComputation));

auto syncfn = [async, hash]() {
auto syncfn = [async, hash, sharding_specs]() {
Collaborator

why does sharding_specs need to be passed from the original function? It is initialized with a default value, right? We might as well just initialize it within this function.

Collaborator

ok, I guess the point here is that we need to supply a proper sharding spec for the output, but that is currently not supported yet. As a follow-up, what's the right way to determine the sharding spec for the output?

Collaborator Author

Good question, I'm not sure of the right way to determine the sharding spec for the output. @yeounoh might be able to answer or provide more context.

Collaborator

That's fine, we will leave that as a follow-up.
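One naive shape the deferred follow-up could take: propagate a sharding spec from the inputs when they agree, and fall back to replication otherwise, since replicated output is always correct (if wasteful). A hedged sketch in plain Python; the spec values and the propagation rule are hypothetical, not the actual torch_xla behavior:

```python
# Hypothetical sketch of output sharding-spec propagation. The string
# specs ("batch", "replicated") are illustrative stand-ins, not real
# torch_xla sharding types.

REPLICATED = "replicated"

def output_sharding_spec(input_specs):
    """If every annotated input agrees on one spec, reuse it for the
    output; otherwise fall back to replicated."""
    specs = {s for s in input_specs if s is not None}
    if len(specs) == 1:
        return specs.pop()
    return REPLICATED

assert output_sharding_spec(["batch", "batch"]) == "batch"
assert output_sharding_spec(["batch", "model"]) == REPLICATED
assert output_sharding_spec([None, None]) == REPLICATED
```

A real implementation would instead run XLA's sharding propagation pass over the computation, which this sketch deliberately ignores.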

Collaborator

@JackCaoG JackCaoG left a comment

mostly LGTM.

FYI @wonjoolee95: in this PR we break a really important assumption of dynamo, namely that the same fx graph will result in the same IR/HLO graph. In the SPMD case, even if the fx graph is the same, we can annotate the input XLATensor and make it generate a different IR/HLO graph.

There is one thing that is still unclear to me, and I hope @yeounoh and maybe @alanwaketan can give me an answer. AFAIK, during the initial phase of dynamo's extract_compile_graph, dynamo will pass us the fx graph along with fake XLATensors with the same shapes as the real model. Does dynamo somehow correctly move the sharding annotation from the real XLATensor to the fake XLATensor? Otherwise I would expect we hit a cache miss when we execute the real SPMD computation.
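The cache-miss concern above can be illustrated with a toy compilation cache in plain Python (the cache and compile function are hypothetical, not the dynamo bridge's actual code): if the key were the traced graph alone, two runs of the same graph with different input sharding annotations would collide, so the annotations must be part of the key.

```python
# Illustrative sketch of a compilation cache whose key includes the
# input sharding annotations, so identical fx graphs with different
# sharding still compile separately.

compiled_cache = {}
compile_count = 0

def compile_graph(fx_graph_src, input_shardings):
    global compile_count
    # Key on both the graph and the sharding annotations.
    key = (fx_graph_src, tuple(input_shardings))
    if key not in compiled_cache:
        compile_count += 1      # a real backend would lower to HLO here
        compiled_cache[key] = f"hlo<{fx_graph_src}|{input_shardings}>"
    return compiled_cache[key]

graph = "def f(x): return x + 1"
compile_graph(graph, ["replicated"])
compile_graph(graph, ["replicated"])   # cache hit: same graph, same sharding
compile_graph(graph, ["batch"])        # recompile: same graph, new sharding
assert compile_count == 2
```

If the fake tensors dynamo traces with did not carry the real tensors' sharding annotations, the trace-time key and the run-time key would differ, which is exactly the cache miss the comment worries about.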

@steventk-g steventk-g force-pushed the steven_spmd_inference branch from d774de0 to cf1f0f4 Compare April 24, 2023 16:26
@JackCaoG
Collaborator

@steventk-g let me know when you want me to take another look

@JackCaoG
Collaborator

JackCaoG commented May 2, 2023

I am taking over this commit; since it is on @steventk-g's fork and already has conflicts with master, I will just open a new PR and patch from master.

@alanwaketan
Collaborator

> I am taking over this commit; since it is on @steventk-g's fork and already has conflicts with master, I will just open a new PR and patch from master.

Let me know when it's ready. I can review it.

@JackCaoG
Collaborator

Close this one in favor of #5002

@JackCaoG JackCaoG closed this May 12, 2023