add torch_xla.experimental.compile for eager mode #7246

Merged
JackCaoG merged 3 commits into master from JackCaoG/torch_xla_compile2
Jun 12, 2024
Conversation

@JackCaoG
Collaborator

This should only be used in eager mode. The compile essentially enables LTC before entering the function and disables it again afterwards.

TODO

  1. add more unit test

@JackCaoG JackCaoG added the usability Bugs/features related to improving the usability of PyTorch/XLA label Jun 11, 2024
@JackCaoG JackCaoG requested review from lsy323 and qihqi June 11, 2024 23:05
@JackCaoG JackCaoG marked this pull request as ready for review June 12, 2024 00:18
@JackCaoG
Collaborator Author

ok, test added, should be ready for review.

Collaborator

@qihqi left a comment


lgtm, a few quick questions:

```python
      result = func(*args, **kwargs)
    except Exception as e:
      # Handle exceptions (if needed)
      print(f"Error in target function: {e}")
```
Collaborator


Here the exception is a tracing exception, right?

Collaborator Author


yea, execution is async, so we won't be able to catch it here.

```python
      print(f"Error in target function: {e}")
      raise  # Re-raise the exception
    # Sync the graph generated by the target function.
    torch_xla.sync()
```
Collaborator


this actually runs the graph, and you might get exceptions here too.

Collaborator Author


right, the way LTC works is that async execution happens in a separate thread and the runtime error is set in the unlocker. The next time we try to get the device lock, it will find that exception and throw:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L861-L864

That being said, I agree with you that there is no harm in putting the sync inside the try region. Let me update that in a follow-up PR.
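The deferred-error pattern described here can be sketched in plain Python. This is an assumed simplification of LTC's actual C++ mechanism; `Device`, `run_async`, and `acquire` are hypothetical names for illustration:

```python
import threading


class Device:
  """Toy model of a device whose async worker defers errors to the next lock."""

  def __init__(self):
    self._lock = threading.Lock()
    self._pending_error = None

  def run_async(self, fn):
    # Run fn on a separate thread; any exception is stashed, not raised here.
    def worker():
      try:
        fn()
      except Exception as e:
        self._pending_error = e  # record the failure for a later caller

    t = threading.Thread(target=worker)
    t.start()
    t.join()  # joined only to keep this example deterministic

  def acquire(self):
    # The next attempt to take the device lock surfaces the stored exception.
    self._lock.acquire()
    err, self._pending_error = self._pending_error, None
    if err is not None:
      self._lock.release()
      raise err


dev = Device()
dev.run_async(lambda: 1 / 0)  # fails on the worker thread; nothing raised here
try:
  dev.acquire()               # the stored ZeroDivisionError surfaces now
except ZeroDivisionError:
  print("error surfaced at next device lock")
```

The call site that scheduled the work never sees the exception directly; it pops out wherever the device lock is next taken, which is why wrapping only the tracing call in `try` is not enough to catch execution errors.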

"""

@functools.wraps(func) # Keep function's name, docstring, etc.
def wrapper(*args, **kwargs):
Collaborator


So the mechanism for caching the graph is already there, right? So there's no need to do anything extra?

Collaborator Author


we still need to trace the whole model (run all the Python code); we just skip the XLA compilation and lowering-to-HLO part.

This compile does not modify the function bytecode or transform the function in any way besides enabling tracing mode. It would actually be more accurate to call it trace, but compile aligns better with the PyTorch API.
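A sketch of this point, with the hypothetical name `trace_mode_compile` standing in for the real API: the wrapper only toggles a mode around the call, while `functools.wraps` preserves the function's metadata and the original bytecode is untouched.

```python
import functools


def trace_mode_compile(func):  # hypothetical name, for illustration only
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    # (a real implementation would enable tracing mode here)
    return func(*args, **kwargs)

  return wrapper


def train_step(x):
  """One training step."""
  return x * 2


compiled = trace_mode_compile(train_step)
assert compiled.__name__ == "train_step"        # metadata preserved by wraps
assert compiled.__doc__ == "One training step."
assert compiled.__wrapped__.__code__ is train_step.__code__  # same bytecode
print(compiled(3))  # 6: behavior is unchanged; only the mode flag would differ
```

Contrast this with torch.compile-style bytecode transformation: here the original code object is reachable and identical, which is why "trace" describes the mechanism more precisely.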

@JackCaoG JackCaoG merged commit 90168e8 into master Jun 12, 2024
@JackCaoG JackCaoG added the eager PyTorch/XLA eager-mode label Jun 20, 2024
