Skip to content

Ignore non-XLA nodes and their direct dependents.#6170

Merged
ysiraichi merged 5 commits intomasterfrom
ysiraichi/ignore-non-xla-nodes
Jan 11, 2024
Merged

Ignore non-XLA nodes and their direct dependents.#6170
ysiraichi merged 5 commits intomasterfrom
ysiraichi/ignore-non-xla-nodes

Conversation

@ysiraichi
Copy link
Copy Markdown
Collaborator

Fix: #5966

This PR generalizes FallBackNodeCollector into UnsupportedNodesCollector, improving and solving a few issues the old implementation had:

  • nodes which resulted in a non-XLA tensor weren't flagged as fallback
  • nodes whose arguments were a container with non-XLA tensors weren't flagged as fallback (e.g. stack)
  • tracking these "unsupported" nodes weren't really only fallback nodes
    • but, nodes that can't exist in partition boundaries (e.g. arguments and return values)

@ysiraichi
Copy link
Copy Markdown
Collaborator Author

cc @JackCaoG @miladm

@ysiraichi
Copy link
Copy Markdown
Collaborator Author

Note: this PR can, possibly, create more partitions than before, affecting performance. Maybe we should make sure there are no regressions before actually landing it.

@JackCaoG
Copy link
Copy Markdown
Collaborator

I think the definition of the fallback is that we have to execute this op on fallback devices(uusually) cpu. for the cases you mentioned

  1. nodes which resulted in a non-XLA tensor weren't flagged as fallback
    Does the operation being executed on cpu or gpu/tpu?

  2. nodes whose arguments were a container with non-XLA tensors weren't flagged as fallback (e.g. stack)
    I guess this one is technically a fallback, since op will be executed on cpu. It will show up in dynamo fallback messages, but shouldn't be in the XLA's metrics aten counter I think?

I don't fully understand what 3 means, can you give ma an example?

@ysiraichi
Copy link
Copy Markdown
Collaborator Author

ysiraichi commented Jan 8, 2024

The problem was that FallBackNodeCollector was doing more than just collecting operations that were executed on CPU. It was also flagging nodes, that had non-XLA tensors arguments, as fallback (even though they were executed on XLA).

Reason: my guess is that it did so, in order to guarantee that the input/output of the generated partitions were all XLA tensors (which I believe extract_internal assumes).

Assuming that's the case, it still missed the 2 cases I mentioned:

  • nodes which result in non-XLA tensor: might end up as output of some partition
  • nodes whose arguments were a container with non-XLA tensors: arguments might end up as the input of some partition

Solution:

  • disambiguate fallback from unsupported nodes (from the perspective of the partitioner)
  • add the missing 2 cases

@JackCaoG let me know what you think.

@JackCaoG
Copy link
Copy Markdown
Collaborator

JackCaoG commented Jan 8, 2024

yea that make sense, is this pr ready for review?

@ysiraichi
Copy link
Copy Markdown
Collaborator Author

Yes, it is.

Comment thread test/dynamo/test_bridge.py Outdated
Comment thread torch_xla/core/dynamo_bridge.py Outdated
Copy link
Copy Markdown
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm, minor questions on test cases.

@ysiraichi ysiraichi force-pushed the ysiraichi/ignore-non-xla-nodes branch from c38c1a0 to 65cd98a Compare January 10, 2024 22:34
@ysiraichi ysiraichi force-pushed the ysiraichi/ignore-non-xla-nodes branch from 65cd98a to 416bccb Compare January 10, 2024 23:28
@ysiraichi
Copy link
Copy Markdown
Collaborator Author

@JackCaoG could you approve this PR?

@ysiraichi ysiraichi merged commit 8141078 into master Jan 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[torchbench] AssertionError: All tensors should be on xla

2 participants