
Add torch._dynamo.is_fullgraph_compiling to allow different codepath depending on fullgraph tracing #120400

Closed
fxmarty wants to merge 7 commits into pytorch:main from fxmarty:expose-compile-fullgraph-detection

Conversation

@fxmarty

@fxmarty fxmarty commented Feb 22, 2024

This PR fixes https://pytorch.slack.com/archives/C033H6DJSJU/p1708510833453919 and allows implementing different code paths depending on whether torch.compile is called with the argument fullgraph=True.

Example (see also the unit test):

def f(x):
    if torch._dynamo.is_fullgraph_compiling():
        # fullgraph=True compliant code path
        ...
    else:
        # more permissive code path
        ...
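The intended behaviour can be sketched without torch, using a hypothetical stand-in for the proposed torch._dynamo.is_fullgraph_compiling(): a flag that the compiler frontend would set while tracing with fullgraph=True. All names here are illustrative, not PyTorch's.

```python
# Hypothetical stand-in for the proposed torch._dynamo.is_fullgraph_compiling().
# A real implementation would query the active InstructionTranslator; this
# torch-free sketch just uses a module-level flag to show the intended branching.
_FULLGRAPH_TRACING = False

def is_fullgraph_compiling():
    return _FULLGRAPH_TRACING

def f():
    if is_fullgraph_compiling():
        return "fullgraph-compliant path"
    return "permissive path"

# Eager call: the flag is unset, so the permissive branch runs.
assert f() == "permissive path"

# Simulate what tracing under fullgraph=True would observe.
_FULLGRAPH_TRACING = True
assert f() == "fullgraph-compliant path"
_FULLGRAPH_TRACING = False
```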

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Feb 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120400

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures

As of commit 86c5756 with merge base 8a32a07:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Feb 22, 2024

Please seek CI approval before scheduling CIFlow labels


Comment on lines +288 to +289
# See: https://github.com/pytorch/pytorch/issues/110765
tx.mark_inconsistent_side_effects()
Author

@jon-chuang I am not sure whether this is necessary here?

Collaborator

@jon-chuang jon-chuang Feb 22, 2024

It's necessary for capturing side-effect-only code. If your code has a torch operation on a tensor input, it makes no difference.
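To illustrate what "side-effect-only code" means here, consider this torch-free sketch: the branch produces no return value and only mutates external state, so a tracer that records outputs alone would capture nothing, which is why dynamo needs an explicit marker like mark_inconsistent_side_effects. The code below is an assumption-laden illustration, not dynamo's mechanism.

```python
# Side-effect-only code: the function returns nothing and is observable
# only through the mutation of external state.
log = []

def branch_on_flag(flag):
    if flag:
        log.append("fullgraph path taken")  # observable only via mutation
    # no return value: nothing for an output-only trace to record

branch_on_flag(True)
branch_on_flag(False)
assert log == ["fullgraph path taken"]
```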

Comment on lines +5875 to +5883
opt_f = torch.compile(f, fullgraph=True)

self.assertEqual(f(), torch.zeros(2, 2))
self.assertEqual(opt_f(), torch.ones(2, 2))

opt_g = torch.compile(g, fullgraph=False)

self.assertEqual(g(), torch.zeros(2, 2))
self.assertEqual(opt_g(), torch.zeros(2, 2))
Author

One issue here is that when calling torch.compile twice on the same function:

opt_f = torch.compile(f, fullgraph=False)
opt_f()
opt_f = torch.compile(f, fullgraph=True)
opt_f()

somehow torch.compile does not initialize a new InstructionTranslator at the second call, so the one_graph attribute is wrong; see https://app.slack.com/client/T2077MDKQ/C033H6DJSJU

Is it legal to call torch.compile several times on the same Python object/function?

Collaborator

@jon-chuang jon-chuang Feb 22, 2024

I think if your original graph was a single graph, you don't recompile on the second invocation.

Collaborator

@jon-chuang jon-chuang Feb 22, 2024

You may or may not need to change this behaviour; it's an easy optimization that held previously.

If this PR proposes to capture different graphs based on the presence of this function call, then you can make the behaviour stricter (i.e. have a change in the fullgraph flag always trigger recompilation), but only when your new function call is present; see the guards for the fullgraph flag.
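The guard behaviour being suggested can be sketched in plain Python: cache compiled artifacts per function, and treat a change in the fullgraph flag as a guard failure that forces recompilation. Names and structure are illustrative, not dynamo's actual cache layout.

```python
# Sketch of a fullgraph-aware compilation cache: the cache key includes
# the fullgraph flag, so changing the flag is a guard miss and recompiles.
compiled_cache = {}  # maps (fn, fullgraph) -> compiled artifact

def compile_with_guard(fn, fullgraph):
    key = (fn, fullgraph)
    if key not in compiled_cache:
        # Guard miss: first call, or the fullgraph flag changed. Recompile.
        compiled_cache[key] = f"artifact({fn.__name__}, fullgraph={fullgraph})"
    return compiled_cache[key]

def f():
    pass

a = compile_with_guard(f, fullgraph=False)
b = compile_with_guard(f, fullgraph=False)  # guard hit: cached artifact reused
c = compile_with_guard(f, fullgraph=True)   # flag changed: new artifact
assert a is b
assert a != c
```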

Author

If this PR proposes to capture different graphs based on presence of this function call

Yes, this PR proposes to capture different graphs based on torch._dynamo.is_fullgraph_compiling() (i.e. based on the fullgraph argument). This works well when we don't call torch.compile multiple times on the same function/object, but it currently breaks when calling torch.compile successively on the same function/object because the InstructionTranslator is not re-initialized.

then you can make the behaviour stricter - i.e. cause fullgraph flag changing to always recompile - only when your new function call is present; see guards for fullgraph flag.

Could you point me to the file responsible for that? I could not find nopython, one_graph, or fullgraph references in guards.py or mutation_guard.py.

Collaborator

It may have been removed, unfortunately; a bunch of code got reverted due to some Meta-internal failures (the investigation hasn't concluded, AFAIK).

You may have to introduce a new guard that recompiles when the fullgraph flag changes, if this behaviour is strictly necessary.

Author

It was removed in #115384

@albanD albanD requested a review from anijain2305 February 22, 2024 15:57
@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 22, 2024
@Skylion007
Collaborator

@janeyx99 This looks like it would be useful to automatically set the capturable parameter on torch optimizers?

@Skylion007 Skylion007 requested a review from janeyx99 February 22, 2024 22:31

@janeyx99
Contributor

@janeyx99 This looks like it would be useful to automatically set the capturable parameter on torch optimizers?

Well, we want capturable to be enabled for non-fullgraph tracing too, so this shouldn't change how that logic is currently set up.


or current_backend == cached_backends.get(backend_obj_id, None)
)

def check_nopython(ref_nopython: bool):
Author

The changes in eval_frame.py are probably far from ideal, but they work.

This check_nopython actually seems to be called twice at each guard check; I'm not sure why:

--------------- f fullgraph=True
--------------- fullgraph=False
guarded_backend_cache.nopython False
ref_nopython True
guarded_backend_cache.nopython False
ref_nopython True
--------------- f fullgraph=True
call forward
guarded_backend_cache.nopython True
ref_nopython False
guarded_backend_cache.nopython True
ref_nopython True
--------------- fullgraph=False
guarded_backend_cache.nopython False
ref_nopython True
guarded_backend_cache.nopython False
ref_nopython False
--------------- fullgraph=True
guarded_backend_cache.nopython True
ref_nopython False
guarded_backend_cache.nopython True
ref_nopython True
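A minimal sketch of the check being discussed, mirroring the names in the log above (guarded_backend_cache.nopython, ref_nopython) but not eval_frame.py's actual layout: a cached entry records the nopython (fullgraph) flag it was compiled under, and the guard passes only when the current call uses the same flag.

```python
# Illustrative nopython guard: reuse cached code only when the cached
# nopython flag matches the flag requested for the current call.
class GuardedBackendCache:
    def __init__(self, nopython):
        self.nopython = nopython

def check_nopython(cache, ref_nopython):
    # Guard passes only if the cached flag matches the requested one.
    return cache.nopython == ref_nopython

cache = GuardedBackendCache(nopython=False)
assert check_nopython(cache, False)      # same flag: cached code is reusable
assert not check_nopython(cache, True)   # flag changed: guard fails, recompile
```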


@fxmarty
Author

fxmarty commented Feb 26, 2024

Some tests were not passing in CI but do pass locally; I'm not sure why:

2024-02-26T05:48:12.3626098Z FAILED [0.1356s] test_sparse_csr.py::TestSparseCSRCPU::test_sparse_to_sparse_compressed_SparseBSC_cpu_float64 - torch._dynamo.exc.InternalTorchDynamoError: nnz not found
2024-02-26T06:09:20.6013131Z FAILED [0.0583s] functorch/test_rearrange.py::TestRearrange::test_concatenations_and_stacking - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:09:20.6358400Z FAILED [0.0714s] functorch/test_rearrange.py::TestRearrange::test_ellipsis_ops - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:09:20.6690803Z FAILED [0.0971s] functorch/test_rearrange.py::TestRearrange::test_rearrange_consistency - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:09:20.6960094Z FAILED [0.0425s] functorch/test_rearrange.py::TestRearrange::test_rearrange_permutations - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:09:20.7040419Z FAILED [0.0775s] functorch/test_rearrange.py::TestRearrange::test_squeeze - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:09:20.7120616Z FAILED [0.0536s] functorch/test_rearrange.py::TestRearrange::test_unsqueeze - torch._dynamo.exc.InternalTorchDynamoError: dimension d0 is unbound
2024-02-26T06:19:56.1722324Z FAILED [0.0209s] torch_np/test_basic.py::TestOneArr::test_asarray_array_func0 - torch._dynamo.exc.InternalTorchDynamoError: Boolean value of Tensor with more than one value is ambiguous
2024-02-26T06:12:06.8184282Z FAILED [0.0375s] test_jit.py::TestTypeSharing::test_tracing_gives_different_types - torch._dynamo.exc.InternalTorchDynamoError: __eq__(): incompatible function arguments. The following argument types are supported:
2024-02-26T06:12:06.8184536Z     1. (self: torch._C.Type, arg0: torch._C.Type) -> bool

There was

2024-02-26T06:07:38.9873830Z FAILED [0.1431s] test_xnnpack_integration.py::TestXNNPACKOps::test_conv2d_transpose - hypothesis.errors.Flaky: Hypothesis test_conv2d_transpose(self=<__main__.TestXNNPACKOps testMethod=test_conv2d_transpose>, batch_size=1, input_channels_per_group=1, height=5, width=5, output_channels_per_group=1, groups=1, kernel_h=1, kernel_w=1, stride_h=1, stride_w=1, pad_h=0, pad_w=0, output_pad_h=0, output_pad_w=0, dilation=1, use_bias=False, format=None) produces unreliable results: Falsified on the first call but did not on a subsequent one

as well, but I did not compile with XNNPACK, so I'm not sure whether this one passes locally.


@fxmarty
Author

fxmarty commented Feb 29, 2024

Hi, I'll be away for two weeks starting next week, but I'm happy to make modifications here by Friday. It would be cool to have this in 2.3.

@fxmarty
Author

fxmarty commented Apr 15, 2024

Any update, @anijain2305? Having torch._dynamo.is_fullgraph_compiling() and torch._dynamo.is_exporting() would be helpful for us at HF. Are you interested in adding these to PyTorch?

@anijain2305
Contributor

@fxmarty Can you share why you need this? The changes required for this are quite awkward. Wondering if you have a specific problem that we can fix so that this flag is not needed.

@fxmarty
Author

fxmarty commented Apr 18, 2024

Three use cases where I would have wanted this:

Basically, I would want to have different execution paths depending on fullgraph and/or export.

@fxmarty
Author

fxmarty commented Apr 26, 2024

any thoughts @anijain2305?

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@fxmarty
Author

fxmarty commented Jul 12, 2024

any interest?

@github-actions github-actions bot closed this Aug 11, 2024
amodab01 referenced this pull request in huggingface/transformers Feb 10, 2025
* update non-causal mask for sdpa

* add test

* update docstrings

* add one more test

* fix cross attention bug

* gentler atol/rtol

Labels

ciflow/inductor module: dynamo open source Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

7 participants