[Dynamo] Refine CPU fallback for TD+XLA #4935

Closed

seanlatias wants to merge 14 commits into pytorch:master from seanlatias:cpufallback

Conversation

@seanlatias
Collaborator

In this PR, we refined the CPU fallback mechanism originally implemented by @wonjoolee95 and made the following changes.

  1. Fix the check for XLA support.
    We use torch.fx.Interpreter to execute each node of the input module and use torch_xla.debug.metrics to check whether that node goes through CPU fallback. If it does, the node is not supported by XLA (see the sketch below).

  2. Fix the fallback mechanism.
    2.1. Partition the graph according to the results from (1), producing a new graph whose subgraphs contain only ops supported by XLA.
    2.2. Compile each subgraph into a compiled function call.
    2.3. Replace each subgraph with its compiled function call.

With the above changes, we can even cover cases where a combination of operator and operand is not supported.
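
As a reference, here is a minimal sketch of the detection idea in (1). The collector's exact implementation in this PR may differ; the placement of clear_counters and the "aten::<op>" counter-name convention are assumptions based on this discussion.

import torch
import torch.fx
import torch_xla.debug.metrics as met


class FallBackNodeCollector(torch.fx.Interpreter):
  """Executes each node and records the ones that trigger a CPU fallback."""

  def __init__(self, module):
    super().__init__(module)
    self.fallback_ops = []

  def run_node(self, n: torch.fx.Node):
    met.clear_counters()
    result = super().run_node(n)
    # torch_xla registers an "aten::<op>" counter whenever an op falls
    # back to the CPU implementation instead of being lowered to XLA.
    if any(name.startswith("aten::") for name in met.counter_names()):
      self.fallback_ops.append(n)
    return result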

cc: @JackCaoG @alanwaketan

@JackCaoG requested review from JackCaoG and wonjoo-wj on April 24, 2023 at 17:39
Comment thread torch_xla/csrc/init_python_bindings.cpp Outdated
XLA_CHECK(false)
<< "_check_tensor_need_materialization "
"currently does not handle XLATensor without XLAData and IR";
need_materialization.push_back(true);
Collaborator

I think the right thing to do is to check xtensor->CurrentTensorData(); if it is not nullptr, then

need_materialization.push_back(false);

since we don't need to execute a computation for an XLATensor that has tensor_data (tensor data is a CPU at::Tensor whose value we already know). We should leave the last else branch to throw an error to catch any future unhandled case.
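
Concretely, the suggested branch would look roughly like this (a sketch only; apart from CurrentTensorData(), the code mirrors the quoted snippet rather than the actual patch):

if (xtensor->CurrentTensorData()) {
  // The XLATensor already holds a CPU at::Tensor whose value we know,
  // so no computation needs to run to materialize it.
  need_materialization.push_back(false);
} else {
  // Keep a hard error as the last branch to catch future unhandled cases.
  XLA_CHECK(false) << "_check_tensor_need_materialization "
                      "currently does not handle XLATensor without "
                      "XLAData and IR";
}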

Collaborator Author

Done.

Comment thread test/dynamo/test_dynamo_fallback.py Outdated

@dynamo.optimize("torchxla_trace_once")
def fn_fallback(M, mat1, mat2):
A = torch.cummin(M, 1)
Collaborator

is cummin supposed to fall back?

Collaborator Author

Yes, I have checked that cummin is not supported. I think one problem with this kind of testing is that once we support the currently unsupported ops, we may need to update the test. I'm thinking maybe I can create a custom op that will never be supported, to avoid this problem.

Collaborator

For the scope of this PR, I think you can just leave a TODO comment. We can track that kind of specialized unit test in a separate issue, if that sounds okay.

Collaborator Author

Sure.

Comment thread test/dynamo/test_dynamo_fallback.py Outdated
mat1 = torch.randn(5, 10, device=xm.xla_device())
mat2 = torch.randn(5, 10, device=xm.xla_device())

res = fn_fallback(M, mat1, mat2)
Collaborator

@JackCaoG Apr 24, 2023

I think we should

  1. Check that res matches the eager result.
  2. Check the counter (try met.short_metrics_report() and you should see the fallback counter? though maybe the fallback happens in pytorch, in which case we don't even see it). I actually don't know what to expect, but if there is a fallback I assume ExecuteTime will be 2 instead of 1. You also need to run fn_fallback once to let it compile, then clear the counters and rerun to see ExecuteTime at 2 instead of 1.

For more counter related stuff, check https://github.com/pytorch/xla/blob/master/torch_xla/debug/metrics.py
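
A rough sketch of those two checks, assuming the test is restructured so fn_fallback is the plain eager function and dynamo_fn is its dynamo-optimized version (as in a later revision of this test); the exact expected values are assumptions to confirm:

import torch
import torch_xla.debug.metrics as met

# 1. The XLA result should match the eager CPU result.
cpu_res = fn_fallback(M.cpu(), mat1.cpu(), mat2.cpu())
xla_res = dynamo_fn(M, mat1, mat2)
assert torch.allclose(cpu_res, xla_res.cpu())

# 2. The call above already compiled the graph; clear the metrics and
#    rerun. With a fallback in the middle of the graph, we expect two
#    executions instead of one.
met.clear_counters()
met.clear_metrics()
xla_res = dynamo_fn(M, mat1, mat2)
assert met.metric_data('ExecuteTime')[0] == 2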

Collaborator Author

Will do.

Collaborator

@wonjoo-wj Apr 25, 2023

If we don't see the mentioned PyTorch's counter, we can also create/use our own metric when we collect the fallback nodes -- could be useful for debugging/testing.

Comment thread test/dynamo/test_dynamo_fallback.py Outdated
mat1 = torch.randn(2, 3, device=xm.xla_device())
mat2 = torch.randn(3, 3, device=xm.xla_device())

res = fn_fallback(M, mat1, mat2, 0.5)
Collaborator

ditto

Collaborator

@wonjoo-wj left a comment

Thanks, @seanlatias! It seems like some existing dynamo tests are failing, mostly due to metrics. The errors should be reproducible locally:

python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_resnet18

new_node = partitioned_graph.graph.call_function(
    extract_internal(fused_module), node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)
Collaborator

Wondering why we need this partitioned_graph.graph.erase_node(node)? Can we do graph.eliminate_dead_code() like https://github.com/pytorch/pytorch/blob/0d66db1b2a9470a50d930308dbffda017500b80b/torch/_prims/nvfuser_executor.py#L465 or is this something different?

Collaborator Author

Here I'm explicitly removing the old node (the call to the submodule), which I replace with a new node (a call to the optimized function). eliminate_dead_code() goes through the entire graph and checks whether each node is still used. Ideally the old node is no longer used, so the two approaches should lead to the same result. However, when I tried eliminate_dead_code(), it hit some errors while trying to eliminate the old node.
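
Roughly, the pattern is the following (fused_module and extract_internal come from the quoted diff; the insertion-point handling here is a sketch, not the exact patch):

# Insert the new call where the old one sits so evaluation order is kept.
with partitioned_graph.graph.inserting_before(node):
  new_node = partitioned_graph.graph.call_function(
      extract_internal(fused_module), node.args, None)
# Redirect every consumer of the old submodule call to the compiled call,
# then explicitly erase the now-dead node instead of relying on
# eliminate_dead_code(), which errored out here in practice.
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)
partitioned_graph.recompile()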

@seanlatias
Collaborator Author

Thanks, @seanlatias! It seems like some existing dynamo tests are failing, mostly due to metrics. The errors should be reproducible locally:

python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_resnet18

Do you know why sometimes mark_step increases the compilation count?

@JackCaoG
Collaborator

mark_step will cut the current graph and compile/execute it. It is expected that mark_step increases both the Compile and Execute counts. After the initial dynamo init, running the dynamo program should not trigger mark_step. You can check https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py#L71 for a more detailed example.
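
A tiny illustration of that behavior (the counts assume a fresh process with no prior compilations):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.randn(2, 2, device=xm.xla_device())
u = t + 1        # traced lazily; nothing has compiled or executed yet
xm.mark_step()   # cuts the graph here and compiles/executes it

print(met.metric_data('CompileTime')[0])  # 1
print(met.metric_data('ExecuteTime')[0])  # 1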

@seanlatias
Collaborator Author

OK. In the current dynamo bridge implementation, I see two places that increase the Compile count: one is the mark_step at the entry point (xm.mark_step()), and the other is when caching the graph (torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])). However, the resnet18 test expects the Compile count to be just one (self.assertEqual(met.metric_data('CompileTime')[0], 1)), and I think this is one of the errors I'm getting.

@wonjoo-wj
Collaborator

The metric comparison failure seems to have to do with how metrics.clear_counters() was moved. Previously, we did metrics.clear_counters() after mark_step() at https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L223. With this PR, we moved this metrics.clear_counters() into the run_node() function at https://github.com/pytorch/xla/pull/4935/files#diff-158d06d13623de8a2b4c4ee54902f148a8e0abda80870e5f25ee0bef3cc369b0R323.

Let me also try to build this branch locally to confirm.

@seanlatias
Collaborator Author

Thanks, Wonjoo. But clear_counters() does not affect the Compile metric. Please let me know if you also see similar issues.

@seanlatias
Collaborator Author

On second thought, I think it makes sense for my approach to introduce extra Compile and Execute counts, because I do execute the graph one more time to collect fallback info. So if this approach makes sense to you, the metric-checking assertions in the other dynamo-related tests will need to be changed. I will first push code with only the metric-checking failures, and we can discuss from there.

@seanlatias
Collaborator Author

For the optimizer test, the number is way off because of an issue with the partitioner. I have created a PR (pytorch/pytorch#100195) to solve it.

@wonjoo-wj
Collaborator

On second thought, I think it makes sense for my approach to introduce extra Compile and Execute counts, because I do execute the graph one more time to collect fallback info. So if this approach makes sense to you, the metric-checking assertions in the other dynamo-related tests will need to be changed. I will first push code with only the metric-checking failures, and we can discuss from there.

Thanks for the investigation, Sean. Nice, that makes a lot of sense to me. Let's update the Compile and Execute metrics in the tests. I also wonder what implications this has for performance, since we're now adding another layer of execution on each initial trace, but we can look into that later.

Also, if you could add this finding as comments in our dynamo tests explaining the change in the Compile/Execute metrics, that'd be great.

@wonjoo-wj
Collaborator

For the optimizer test, the number is way off because of an issue with the partitioner. I have created a PR (pytorch/pytorch#100195) to solve it.

Thanks for opening the PR, I've just triggered the CI run on it. LGTM, but let's see if the reviewer has any feedback. Also, to make this PR build/test based off that PyTorch PR, we can add a PyTorch pin to this PR. If you add a file .torch_pin under torch_patches/ with the PyTorch PR number, the CI on this PR will build PyTorch based off that PR. Example PR adding such a pin: 40f41fb.
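
Concretely, something like the following (the exact pin-file format follows the example PR above; the "#"-prefixed number here is an assumption):

echo "#100195" > torch_patches/.torch_pin
git add torch_patches/.torch_pin
git commit -m "Pin PyTorch to pytorch/pytorch#100195 for CI"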

@seanlatias
Collaborator Author

For the optimizer test, the number is way off because of an issue with the partitioner. I have created a PR (pytorch/pytorch#100195) to solve it.

Thanks for opening the PR, I've just triggered the CI run on it. LGTM, but let's see if the reviewer has any feedback. Also, to make this PR build/test based off that PyTorch PR, we can add a PyTorch pin to this PR. If you add a file .torch_pin under torch_patches/ with the PyTorch PR number, the CI on this PR will build PyTorch based off that PR. Example PR adding such a pin: 40f41fb.

Thanks, Wonjoo. Do you know how I can request a review in PyTorch, or who I should tag?

@JackCaoG
Collaborator

JackCaoG commented May 2, 2023

@seanlatias I pinged Sherlock offline; we can help you find a reviewer for that PR.

@wonjoo-wj
Collaborator

Had a quick sync with @seanlatias offline. Oddly, it seems like some fallbacks work on the master branch if we just remove the assertion in our dynamo bridge at https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L226-L230.

Using an example similar to the one in our original RFCs (#4742 and pytorch/pytorch#93601), the code below runs without any errors on master:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch._dynamo as dynamo


@dynamo.optimize("torchxla_trace_once")
def fn_fallback(M, mat1, mat2, beta):
  # xla currently only supports alpha and beta == 1
  ret = torch.addmm(M, mat1, mat2, beta=beta)
  ret2 = ret * 2
  return ret2


device = xm.xla_device()
M = torch.randn(2, 3, device=device)
mat1 = torch.randn(2, 3, device=device)
mat2 = torch.randn(3, 3, device=device)

res1 = fn_fallback(M, mat1, mat2, 0.5)
res2 = fn_fallback(M, mat1, mat2, 0.5)
print('res1:', res1)
print('res2:', res2)
print('----------')
print(met.metrics_report())

And in the metrics report I can see that the aten op did fall back to CPU:

Counter: aten::addmm
  Value: 1

Previously, this would fail with an error like

torch._dynamo.exc.BackendCompilerFailed: torchxla_trace_once raised AssertionError: compiler_fn did not return callable

cc @JackCaoG, am I missing something here? It appears the fallback that wasn't working before is now working on master.

Thanks @seanlatias for the investigation, let me know if I'm missing any details here.

@seanlatias
Collaborator Author

seanlatias commented May 9, 2023

Just to provide more details: I have also tried the same fix (i.e., simply removing the check and the assertion) on the Dynamo test suite. So far, 3 huggingface models and 4 timm models that used to fail because of CPU fallback now work correctly (verified with the --accuracy flag). I also dumped the generated HLO code, and it looks good to me. I just want to make sure it is safe to remove this check, because I see that the check has been there from the very beginning, when the Dynamo+XLA integration was introduced (pytorch/pytorch#87741).

Also, in the original issue (pytorch/pytorch#93601), the root cause was actually the assertion failure from invoking the CPU fallback.


cpu_res = fn_fallback(M, mat1, mat2)
xla_res = dynamo_fn(xla_M, xla_mat1, xla_mat2)

Collaborator

Can you use something similar to https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L57 to check that there are no fallback ops in this case?

If this works, the fallback should be happening on the pytorch cpu side and should not trigger a fallback counter on the pytorch/xla end.

We should also check that xla::cummin is not in the counter, in case we lower this op in the future.
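
A sketch of both checks, scanning the counters directly (whether the test would use the linked helper or a direct scan like this is an assumption):

import torch_xla.debug.metrics as met

# No pytorch/xla-side fallback counter should appear: if the fallback
# happened on the PyTorch CPU side, torch_xla never sees the op.
assert not any(name.startswith("aten::") for name in met.counter_names())

# Also guard against cummin silently getting lowered in the future,
# which would change what this test exercises.
assert "xla::cummin" not in met.counter_names()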

cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
Collaborator

Ditto, we should add the counter check. Did we also handle the invalid operand/shape case in this PR? I thought we only handled ops that we don't lower today.

Collaborator

This PR now actually handles the invalid shape/operand case because of how the updated FallBackNodeCollector works. The FallBackNodeCollector now executes each node of the input module, so it can figure out which ops fall back (regardless of whether the fallback happens because the op is simply not implemented in XLA or because specific shapes/operands are not supported).

cpu_res = fn_fallback(M, mat1, mat2, 0.5)
xla_res = dynamo_fn(M, mat1, mat2, 0.5)

self.assertTrue(torch.allclose(cpu_res, xla_res.cpu()))
Collaborator

ditto

Collaborator

For this one we should also check the ExecuteTime metric (you need to clear the metrics first). We are expecting to see 2 executions, since the fallback happens in the middle of the graph.

none_remover.add_nones(result)
return result
if len(result) == 1:
return result[0]
Collaborator

why is this needed?

@wonjoo-wj
Collaborator

@seanlatias, I was hoping to push directly to this PR/branch, but it seems I can't since it's a local forked branch. Are you able to push directly to #5000? I just rebased this PR against the latest master and fixed some conflicts.

@seanlatias
Collaborator Author

@seanlatias, I was hoping to push directly to this PR/branch, but it seems I can't since it's a local forked branch. Are you able to push directly to #5000? I just rebased this PR against the latest master and fixed some conflicts.

@wonjoolee95 yeah, I can do that. But before doing so, I'm wondering whether this approach is still valid, because we probably only need to remove the fallback check and things will just work.

@seanlatias
Collaborator Author

@wonjoolee95 maybe let me do this: we can create an env var that determines whether to use explicit CPU fallback (my method) or go through the original flow with the check removed. The default will be the original flow with the check removed. How does that sound to you? If you agree, I can add the logic and push it to your branch.

@wonjoo-wj
Collaborator

@wonjoolee95 maybe let me do this: we can create an env var that determines whether to use explicit CPU fallback (my method) or go through the original flow with the check removed. The default will be the original flow with the check removed. How does that sound to you? If you agree, I can add the logic and push it to your branch.

@seanlatias, I had a quick sync with Jack last week regarding the existing behavior and how the fallback works if we just remove the assertion. Based on our discussion and understanding, this shouldn't be the case -- it shouldn't work, as dynamo expects a single hash representing a graph that is entirely executable by XLA. And if the graph includes an unsupported op, it cannot be entirely executable by XLA.

However, as shown in #4935 (comment), I can still reproduce the behavior of unsupported ops working on the master branch if the assertion is simply removed. Let me do a quick experiment and post my findings on why this might be happening. So let's hold off on implementing the flag you mentioned. Also, let's move the discussion to the new PR #5000.

@seanlatias
Collaborator Author

@wonjoolee95 Sounds good. Will do.

@seanlatias
Collaborator Author

Closing this and moving to #5000.
