[WIP] Integrate Dynamo + SPMD for Inference #4862
steventk-g wants to merge 13 commits into pytorch:master from steventk-g's fork
Conversation
if FLAGS.fake_data:
  assert FLAGS.test_set_batch_size == 1
  test_loader = xu.SampleGenerator(
      data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim).to(device),
Maybe we should add a comment explaining that we need the .to(device) here for SPMD + Dynamo?
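A minimal sketch of what that suggested comment could look like in place (the trailing target tensor and sample_count are assumed from the usual fake-data pattern, since the diff is truncated here):

```python
if FLAGS.fake_data:
  assert FLAGS.test_set_batch_size == 1
  test_loader = xu.SampleGenerator(
      # The .to(device) calls are needed so the fake inputs are XLA tensors;
      # SPMD + Dynamo can only annotate and trace device tensors.
      data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim).to(device),
            torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64).to(device)),
      sample_count=50000 // FLAGS.test_set_batch_size)
```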
(Force-pushed from fdcf5f0 to a1e0c14.)
torch.manual_seed(42)

model = torchvision.models.resnet50().to(
    device)  # get_model_property('model_fn')().to(device)
Need to clean this up: either use get_model_property('model_fn')().to(device), or stick with resnet50(), in which case we don't need some of the model-properties code above.
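For illustration, a sketch of what each option might look like (get_model_property and device come from the surrounding test script):

```python
# Option 1: resolve the model through the existing model-properties table.
model = get_model_property('model_fn')().to(device)

# Option 2: hard-code ResNet-50 and drop the model-properties code above.
model = torchvision.models.resnet50().to(device)
```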
import torch_xla.core.dynamo_bridge as bridge
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
python3 pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 pytorch/xla/test/spmd/test_train_spmd_linear_model.py
XLA_USE_SPMD=1 python3 pytorch/xla/test/spmd/test_train_spmd_linear_model.py --sharding batch
XLA_USE_SPMD=1 python3 pytorch/xla/test/spmd/test_inference_spmd_dynamo_imagenet.py --fake_data --use_dynamo --sharding batch
How long will this test take? It is inference, so I guess it won't be too bad, but I am trying to avoid adding too much burden to the already very long CI.
def is_xla_tensor(tensor: torch.Tensor) -> bool:
  # TODO(yeounoh) check if tensor sharding annotation can be accessed here
I think we can remove my comment here now.
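A sketch of what the helper might look like once that TODO is dropped (the body is assumed; only the signature appears in the diff):

```python
def is_xla_tensor(tensor: torch.Tensor) -> bool:
  # A plain device-type check; sharding annotations are handled separately
  # by the SPMD machinery, so they do not need to be inspected here.
  return tensor.device.type == 'xla'
```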
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
Why are these needed? I had these in my dynamo test because the test data is fake, so I needed a mark_step to materialize the fake input. If that's not the case here, we don't need a mark_step.
We need to materialize the input here as well. @yeounoh may be able to add more context.
OK, then maybe add a comment above to mention that this mark_step is there to materialize the fake input.
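A sketch of the suggested comment in place (exact wording assumed):

```python
# Materialize the fake test input on the device before tracing with Dynamo;
# without this mark_step the generated sample data would still be pending IR.
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
```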
    &coll, std::move(arguments), placeholders, std::move(cachedComputation));

-  auto syncfn = [async, hash]() {
+  auto syncfn = [async, hash, sharding_specs]() {
Why does sharding_specs need to be passed in from the original function? It is initialized to a default value, right? We might as well just initialize it within this function.
OK, I guess the point here is that we need to supply a proper sharding spec for the output, but that is not supported yet. As a follow-up, what's the right way to determine the sharding spec for the output?
Good question. I'm not sure of the right way to determine the sharding spec for the output; @yeounoh might be able to answer or provide more context.
That's fine, we will leave that as a follow-up.
JackCaoG left a comment
mostly LGTM.
FYI @wonjoolee95, in this PR we break a really important assumption of Dynamo, namely that the same FX graph will always result in the same IR/HLO graph. In the SPMD case, even if the FX graph is the same, we can annotate the input XLATensor and have it generate a different IR/HLO graph.
There is one thing that is still unclear to me, and I hope @yeounoh and maybe @alanwaketan can give me an answer. AFAIK, during the initial phase of Dynamo's extract_compile_graph, Dynamo passes us the FX graph along with fake XLATensors that have the same shapes as the real model's inputs. Does Dynamo somehow correctly move the sharding annotation from the real XLATensor to the fake XLATensor? Otherwise I would expect us to hit a cache miss when we execute the real SPMD computation.
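To make the concern concrete, a minimal sketch (the mesh shape, device count, and partition spec are made up for illustration; xs.Mesh and xs.mark_sharding come from the torch_xla.experimental.xla_sharding module imported in the diff above):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
t = torch.randn(8, 128).to(device)

# The Python/FX-level computation is identical either way...
def f(x):
  return x @ x.T

# ...but annotating the input changes the HLO the backend has to build,
# so a cache keyed only on the FX graph would be stale after this call.
mesh = xs.Mesh(np.arange(4), (4, 1))  # assumes 4 devices
xs.mark_sharding(t, mesh, (0, 1))     # shard the first dimension
```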
(Force-pushed from d774de0 to cf1f0f4.)
@steventk-g let me know when you want me to take another look.

I am taking over this commit. Since it is on @steventk-g's fork and it already has conflicts with master, I will just open a new PR and patch from master.

Let me know when it's ready. I can review it.

Close this one in favor of #5002.
Includes @yeounoh's changes to use SPMD with torch.compile.
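For context, a rough sketch of the end-to-end flow that line refers to (the dynamo backend name, device count, and sharding details are assumptions based on torch_xla releases around this PR, not something spelled out in the thread):

```python
import numpy as np
import torch
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumes the script is launched with XLA_USE_SPMD=1, as in the test commands above.
device = xm.xla_device()
model = torchvision.models.resnet50().to(device).eval()

# Shard the input batch across a hypothetical 4-device mesh.
num_devices = 4
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1, 1, 1))
batch = torch.zeros(num_devices, 3, 224, 224).to(device)
xs.mark_sharding(batch, mesh, (0, 1, 2, 3))  # shard the batch dimension

# Run inference through dynamo; the backend name varied across releases
# ('torchxla_trace_once' around the time of this PR, 'openxla' later).
compiled = torch.compile(model, backend='torchxla_trace_once')
with torch.no_grad():
  output = compiled(batch)
```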