
fix dynamo inplace copy#7933

Merged
zpcore merged 3 commits into master from piz/inplace-cp
Sep 3, 2024

Conversation

@zpcore
Member

@zpcore zpcore commented Aug 29, 2024

Fix in-place copy, where an extra mark_step was being conducted.

@torch.compile(backend='openxla')
def cc(arg0_1):
  x = torch.randn([1])
  copy = torch.ops.aten.copy.default(arg0_1, x)
  return copy
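For context, here is a plain-Python sketch (no XLA required) of the semantics the reproducer above relies on: aten.copy.default(dst, src) copies src into dst and returns dst, so Dynamo sees an in-place mutation of the graph input. The function below is an illustrative stand-in, not the real operator implementation.

```python
# Illustrative stand-in for aten.copy.default semantics (not the real op):
# copy src into dst in place and return dst, mirroring tensor.copy_.
def aten_copy_default(dst, src):
    dst[:] = src  # mutate the destination in place
    return dst

arg0_1 = [1.0]
new_arg = [2.0]
copy = aten_copy_default(arg0_1, new_arg)
print(copy is arg0_1)  # → True: the graph input itself was mutated
print(arg0_1)          # → [2.0]
```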

@zpcore zpcore added the dynamo label Aug 29, 2024
@JackCaoG
Collaborator

The issue is that InputCollector might also trigger in-place ops; we just need to clear the pending IRs one more time:

--- a/torch_xla/_dynamo/dynamo_bridge.py
+++ b/torch_xla/_dynamo/dynamo_bridge.py
@@ -723,6 +723,18 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     return extract_compiled_graph_helper(xla_model, xla_args)
 
 
+def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
+  # If args_tensor_only has pending IRs, an in-place operation happened. We
+  # don't want to execute that operation yet, so we replace the pending IR
+  # with the cloned arg.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      args_tensor_only)
+
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(args_tensor_only[i], torch.Tensor):
+      args_tensor_only[i].copy_(cloned_args[i])
+
+
 def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
                                         all_xla_args_tensor_only):
   # below logic will try to partition the fx graph based on the fallback ops.
@@ -739,18 +751,8 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
     print('Dynamo fallback ops are' + str(unsupported_nodes) +
           '. Please open a GitHub issue with the above op lowering requests.')
 
-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args_tensor_only)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args_tensor_only[i], torch.Tensor):
-      all_xla_args_tensor_only[i].copy_(cloned_args[i])
+  # UnsupportedNodesCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
 
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
@@ -775,6 +777,9 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
 
+  # InputCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
+
   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:
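
To make the intent of `_clear_pending_irs_on_args` concrete, here is a plain-Python simulation (no torch_xla needed). `FakeTensor` and its `pending_ir` attribute are illustrative stand-ins for an XLA tensor's traced-but-unexecuted IR, not real torch_xla APIs.

```python
class FakeTensor:
    """Minimal stand-in for an XLA tensor: tracks a value plus pending IR."""

    def __init__(self, value):
        self.value = value
        self.pending_ir = None  # set when an in-place op is traced but not run

    def mul_(self, k):
        self.pending_ir = ("mul", k)  # trace the op, don't execute it
        return self

    def copy_(self, other):
        # Restoring from the clone discards the pending IR.
        self.value = other.value
        self.pending_ir = None
        return self


def clear_pending_irs_on_args(args, cloned_args):
    """Mirrors the PR's _clear_pending_irs_on_args: any arg with a pending IR
    is reset from its clone so the traced in-place op is not executed twice."""
    need_update = [a.pending_ir is not None for a in args]
    for i, update in enumerate(need_update):
        if update:
            args[i].copy_(cloned_args[i])
    return args


arg = FakeTensor(3.0)
clone = FakeTensor(3.0)
arg.mul_(2)  # in-place op traced while collecting graph inputs
clear_pending_irs_on_args([arg], [clone])
print(arg.pending_ir)  # → None
print(arg.value)       # → 3.0 (original value restored)
```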

The test I used:

@torch.compile(backend='openxla')
def cc(arg0_1):
  new_arg = arg0_1 * 2
  copy = torch.ops.aten.copy.default(arg0_1, new_arg)
  return copy


device = torch_xla.device()
input = torch.randn([1], device=device)
print(input)
res = cc(input)
print(res)

You can add it to test_dynamo.py and make sure that the * 2 only gets executed once.
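
The reason the clear has to happen twice can be sketched as follows (a toy model, not torch_xla code): every analysis pass that re-runs the graph re-traces the in-place op, so pending IRs reappear after InputCollector even though they were already cleared after UnsupportedNodesCollector.

```python
# Toy model of the pending-IR lifecycle across the two analysis passes.
pending = []


def run_graph_pass(name):
    pending.append(name)  # the pass re-traces the in-place copy


def clear_pending_irs():
    pending.clear()


run_graph_pass("UnsupportedNodesCollector")
clear_pending_irs()  # first clear (already in the existing code)
run_graph_pass("InputCollector")
clear_pending_irs()  # second clear (added by this PR); without it, the
                     # re-traced in-place op would execute again at the
                     # next mark_step
print(pending)  # → []
```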

@zpcore
Member Author

zpcore commented Aug 30, 2024

Thanks for the help, this works now!

@zpcore zpcore marked this pull request as ready for review August 31, 2024 20:38
@zpcore zpcore requested a review from JackCaoG September 3, 2024 16:24
@zpcore zpcore merged commit 989ac69 into master Sep 3, 2024
@zpcore zpcore deleted the piz/inplace-cp branch September 3, 2024 18:21
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024