[Dynamo] Refine CPU fallback for TD+XLA #4935
seanlatias wants to merge 14 commits into pytorch:master from …
Conversation
…have XLAData or IR

```cpp
XLA_CHECK(false) << "_check_tensor_need_materialization "
                    "currently does not handle XLATensor without XLAData and IR";
need_materialization.push_back(true);
```
I think the right thing to do is to check `xtensor->CurrentTensorData()`; if it is not nullptr, then
`need_materialization.push_back(false);`
since we don't need to execute a computation for an XLATensor that has tensor_data (tensor data is a CPU at::Tensor whose value we already know). We should leave the last else branch to throw an error, to catch any future unhandled case.
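A minimal sketch of the suggested branch order, assuming the loop variable `xtensor` and the `need_materialization` vector from the diff above (not the actual implementation):

```cpp
if (xtensor->CurrentTensorData() != nullptr) {
  // The tensor already holds CPU data (an at::Tensor whose value we know),
  // so no computation needs to be executed to materialize it.
  need_materialization.push_back(false);
} else {
  // Keep the last branch as an error to catch any future unhandled case.
  XLA_CHECK(false) << "_check_tensor_need_materialization "
                      "currently does not handle XLATensor without XLAData and IR";
}
```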
```python
@dynamo.optimize("torchxla_trace_once")
def fn_fallback(M, mat1, mat2):
  A = torch.cummin(M, 1)
```
Is cummin supposed to fall back?
Yes, I have checked that cummin is not supported. One problem with this kind of testing is that once we support the currently unsupported ops, we might need to update the test. I'm thinking maybe I can create a custom op that will never be supported, to avoid this problem.
For the scope of this PR, I think you can just leave a TODO comment. We can track that kind of specialized unit test in a separate issue, if that sounds okay.
```python
mat1 = torch.randn(5, 10, device=xm.xla_device())
mat2 = torch.randn(5, 10, device=xm.xla_device())

res = fn_fallback(M, mat1, mat2)
```
I think we should:

- Check that `res` matches the eager result.
- Check the counter (try `met.short_metrics_report()` and you should see the fallback counter? Though maybe the fallback happens in PyTorch and then we don't even see it). I actually don't know what to expect, but if there is a fallback I assume `ExecuteTime` will be 2 instead of 1. You also need to run `fn_fallback` once to let it compile, then clear the counters and rerun to see `ExecuteTime` be 2 instead of 1.

For more counter-related stuff, check https://github.com/pytorch/xla/blob/master/torch_xla/debug/metrics.py
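A sketch of the suggested counter check, assuming the `fn_fallback` test above (the expected count of 2 is the reviewer's hypothesis, not a confirmed value):

```python
import torch_xla.debug.metrics as met

# First call: triggers tracing and compilation.
fn_fallback(M, mat1, mat2)

# Clear counters/metrics, then rerun; if a fallback splits the graph,
# we would expect two XLA executions instead of one.
met.clear_all()
res = fn_fallback(M, mat1, mat2)
print(met.short_metrics_report())
execute_count = met.metric_data('ExecuteTime')[0]  # (count, total, samples)
assert execute_count == 2
```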
If we don't see the mentioned PyTorch counter, we can also create/use our own metric when we collect the fallback nodes -- it could be useful for debugging/testing.
```python
mat1 = torch.randn(2, 3, device=xm.xla_device())
mat2 = torch.randn(3, 3, device=xm.xla_device())

res = fn_fallback(M, mat1, mat2, 0.5)
```
wonjoo-wj left a comment
Thanks, @seanlatias! It seems like some existing dynamo tests are failing, mostly due to metrics. The errors should be reproducible locally:

```
python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_resnet18
```
```python
new_node = partitioned_graph.graph.call_function(
    extract_internal(fused_module), node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)
```
Wondering why we need this `partitioned_graph.graph.erase_node(node)`? Can we do `graph.eliminate_dead_code()` like https://github.com/pytorch/pytorch/blob/0d66db1b2a9470a50d930308dbffda017500b80b/torch/_prims/nvfuser_executor.py#L465, or is this something different?
Here I'm explicitly removing the old node (the call to the submodule), which I replace with a new node (the call to the optimized function). `eliminate_dead_code()` goes through the entire graph and checks whether each node is still used. Ideally my old node is no longer used, so the two approaches should lead to the same result. However, when I tried `eliminate_dead_code()`, it ran into some errors when trying to eliminate the old node.
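For context, a self-contained torch.fx example of the replace-then-erase pattern being discussed (a toy module, not the dynamo bridge code):

```python
import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = fx.symbolic_trace(Toy())
for node in list(gm.graph.nodes):
    if node.op == 'call_function' and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.sigmoid, node.args)
        node.replace_all_uses_with(new_node)
        # erase_node() requires the node to have no remaining users;
        # eliminate_dead_code() would instead sweep the whole graph for
        # unused nodes, which should be equivalent when the replacement
        # left the old node truly unused.
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)
```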
Do you know why sometimes …

Ok, because I see that in the current dynamo bridge implementation there are two places that increase the compile counter (xla/torch_xla/core/dynamo_bridge.py, lines 207 and 279 at 8db3f89), while the test expects the compile count to be just one (xla/test/dynamo/test_dynamo.py, line 82 at 8db3f89).
The metric comparison failure seems like it has to do with how the […]. Let me also try to build this branch locally to confirm.
Thanks Wonjoo. But the …

On second thought, I think it makes sense for my approach to introduce extra …
For the optimizer test, the number is way off because of an issue with the partitioner. I have created a PR (pytorch/pytorch#100195) to solve it.
Thanks for the investigation, Sean. Nice, that makes a lot of sense to me. Let's update the […]. Also, if you could add this finding to comments in our dynamo tests explaining this change in compile/execute time, that'd be great.
Thanks for opening the PR, I've just triggered the CI run on it. LGTM, but let's see if the reviewer has any feedback. Also, to make this PR build/test based off that PyTorch PR, we can add a PyTorch pin to this PR. If you add a file …
Thanks Wonjoo. Do you know how I can request a review in PyTorch, or whom I should tag?

@seanlatias I pinged Sherlock offline, we can help you find a reviewer for that PR.
Had a quick sync with @seanlatias offline. Oddly, it seems like some fallback works on the master branch if we just remove the assertion in our dynamo bridge at https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L226-L230. Using an example similar to the one in our original RFCs (#4742 and pytorch/pytorch#93601), the code below runs without any errors on master: […] And with the metrics report, I can see that the […]. Previously, this would fail with an error like […]. cc @JackCaoG, am I missing something here? It appears the fallback that wasn't working before is now working on master. Thanks @seanlatias for the investigation, let me know if I'm missing any details here.

Just to provide more details: I have also tried using the same fix (i.e., simply removing the check and the assertion) on the Dynamo test suite. So far, 3 HuggingFace models and 4 timm models that used to fail because of CPU fallback now work correctly (by using […]). Also, in the original issue (pytorch/pytorch#93601), the root cause was actually the assertion failure from invoking the CPU fallback.
||
| cpu_res = fn_fallback(M, mat1, mat2) | ||
| xla_res = dynamo_fn(xla_M, xla_mat1, xla_mat2) | ||
|
|
Can you use something similar to https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L57 to check that there are no fallback ops in this case?
If this works, the fallback should happen on the PyTorch CPU side and not trigger a fallback counter on the pytorch/xla end.
We should also check that `xla::cummin` is not in the counters, in case we lower this op in the future.
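A sketch of what those assertions could look like inside the test method above (the `aten::`/`xla::` counter-name prefixes follow torch_xla's convention for fallback and lowered ops):

```python
import torch_xla.debug.metrics as met

met.clear_all()
xla_res = dynamo_fn(xla_M, xla_mat1, xla_mat2)
# If the fallback happened on the PyTorch CPU side, pytorch/xla should not
# have recorded any aten::* fallback counters for this graph.
self.assertFalse(any(name.startswith('aten::') for name in met.counter_names()))
# Guard against a future lowering silently changing this test's meaning.
self.assertNotIn('xla::cummin', met.counter_names())
```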
```python
cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
```
Ditto, we should add the counter check. Did we also handle the invalid operand shape case in this PR? I thought we only handled ops that we don't lower today.
This PR now actually handles the invalid shape/operand case because of how the updated FallBackNodeCollector works. The FallBackNodeCollector now executes each node of the input module, so it can figure out which ops fall back (regardless of whether an op falls back because it is simply not implemented in XLA or because specific shapes/operands are not supported).
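A minimal sketch of that idea, using `torch.fx.Interpreter` plus the fallback counters (illustrative only; the real FallBackNodeCollector in `dynamo_bridge.py` may differ in details):

```python
import torch
from torch import fx
import torch_xla.debug.metrics as met

class FallBackNodeCollector(fx.Interpreter):
    """Executes each node and records those that hit a CPU fallback."""

    def __init__(self, module):
        super().__init__(module)
        self.fallback_ops = []

    def run_node(self, n: fx.Node):
        met.clear_all()
        result = super().run_node(n)
        # torch_xla records CPU fallbacks under counters named "aten::<op>",
        # whether the op is unlowered or just unsupported for these
        # particular shapes/operands.
        if any(name.startswith('aten::') for name in met.counter_names()):
            self.fallback_ops.append(n)
        return result
```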
```python
cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
```
For this one we should also check the metrics for `ExecuteTime` (you need to clear the metrics first). We are expecting to see 2 executions, since the fallback happens in the middle.
```diff
   none_remover.add_nones(result)
-  return result
+  if len(result) == 1:
+    return result[0]
```
@seanlatias, I was hoping to push directly to the PR/branch, but it seems like I can't since it's a local forked branch. Are you able to push directly to #5000? I just rebased this PR with the latest master and fixed some conflicts.
@wonjoolee95 yeah, I can do that. But before doing so, I'm wondering whether this approach is still valid, because we probably only need to remove the fallback checking and things will just work.
@wonjoolee95 maybe let me do this: we can create an env var that determines whether we should use explicit CPU fallback (my method) or go through the original flow with the checking removed. The default will be the original flow with the checking removed. How does that sound to you? If you agree, I can add the logic and push it to your branch. A rough sketch of what I mean is below.
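A rough sketch of the proposed switch (the variable name `XLA_EXPLICIT_CPU_FALLBACK` and the two helper functions are hypothetical, not an agreed API):

```python
import os

# Hypothetical flag: "1" selects the explicit CPU-fallback partitioning path;
# anything else keeps the original flow with the fallback check removed.
USE_EXPLICIT_FALLBACK = os.environ.get('XLA_EXPLICIT_CPU_FALLBACK', '0') == '1'

def compile_fx_module(gm, example_inputs):
    if USE_EXPLICIT_FALLBACK:
        return partition_and_compile(gm, example_inputs)  # hypothetical helper
    return compile_whole_graph(gm, example_inputs)        # hypothetical helper
```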
@seanlatias, I had a quick sync with Jack last week regarding the existing behavior and how the fallback works if we just remove the assertion. Based on our discussion and understanding, this shouldn't be the case -- it shouldn't work, as dynamo expects a single hash representing a graph that is entirely executable by XLA, and if this graph includes an unsupported op, it cannot be entirely executable by XLA. However, as shown in #4935 (comment), I can still reproduce the behavior of unsupported ops working on the master branch if the assertion is simply removed. Let me do a quick experiment and post my findings on why this might be happening. So let's wait on implementing the flag you mentioned. Also, let's move the discussion to the new PR #5000.
@wonjoolee95 Sounds good. Will do.
Closing this and moving to #5000.
In this PR, we refined the CPU fallback mechanism originally implemented by @wonjoolee95 and made the following changes:

1. Fix the checking for XLA support. We use `torch.fx.Interpreter` to execute each node of the input module and use `torch_xla.debug.metrics` to check whether that node goes through CPU fallback. If it does, the node is not supported by XLA.
2. Fix the fallback mechanism (see the sketch after this list):
   2.1. Partition the graph according to the results from (1), which produces a new graph whose subgraphs contain only ops supported by XLA.
   2.2. Compile each subgraph into a compiled function call.
   2.3. Replace each subgraph with its compiled function call.
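A condensed sketch of step 2 using PyTorch's `CapabilityBasedPartitioner` (illustrative; the actual code in `torch_xla/core/dynamo_bridge.py` may differ), with the fallback ops coming from a collector like the one sketched earlier in this thread:

```python
import torch
from torch import fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class XlaOperatorSupport(OperatorSupport):
    """Marks a node supported iff it was not observed to fall back."""

    def __init__(self, fallback_ops):
        super().__init__()
        self.fallback_ops = set(fallback_ops)

    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        return node not in self.fallback_ops

def partition_for_xla(gm: fx.GraphModule, fallback_ops):
    # 2.1: group XLA-supported ops into fused submodules.
    partitioner = CapabilityBasedPartitioner(
        gm, XlaOperatorSupport(fallback_ops), allows_single_node_partition=True)
    partitions = partitioner.propose_partitions()
    partitioned = partitioner.fuse_partitions(partitions)
    # 2.2/2.3: each fused_* submodule would then be compiled, and its
    # call_module node swapped for a call to the compiled function,
    # as in the erase_node snippet discussed above.
    return partitioned
```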
With the above changes, we can even cover cases where a combination of operator and operand is not supported.
cc: @JackCaoG @alanwaketan