add torch_xla.experimental.compile for eager mode #7246
Conversation
OK, test added, should be ready for review.
```python
    result = func(*args, **kwargs)
  except Exception as e:
    # Handle exceptions (if needed)
    print(f"Error in target function: {e}")
```
Here the exception is a tracing exception, right?
Yeah, execution is async, so we won't be able to catch it here.
```python
    print(f"Error in target function: {e}")
    raise  # Re-raise the exception
  # Sync the graph generated by the target function.
  torch_xla.sync()
```
That actually runs the graph, and you might get exceptions here too.
Right, the way LTC works is that async execution happens in a separate thread, and any runtime error is set in the unlocker. The next time we try to acquire the device lock, it will find that exception and throw:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L861-L864
That being said, I agree with you that there is no harm in putting the sync in the try region. Let me update that in a follow-up PR.
| """ | ||
|
|
||
| @functools.wraps(func) # Keep function's name, docstring, etc. | ||
| def wrapper(*args, **kwargs): |
So the mechanism for caching the graph is already there, right? No need to do anything extra?
We still need to trace the whole model (run all of the Python code); we just skip the XLA compilation and lowering-to-HLO parts.
This `compile` does not modify the function bytecode or transform the function in any way besides enabling the tracing mode. It would actually be more accurate to call it `trace`, but `compile` aligns better with the PyTorch API.
This should only be used for eager mode. The `compile` pretty much enables LTC before entering the function and disables it again on exit. TODO