
Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device #6644

Merged
changm merged 3 commits into master from changm/automove on Mar 13, 2024

Conversation

@changm (Collaborator) commented Feb 29, 2024

This currently only works for inference; the assumptions do not hold for training with Autograd yet.
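For context, a minimal sketch of the inference-time behavior this PR describes. Only the `_get_input_arg_device` and `_args_on_cuda` names appear in the review discussion below; `run_with_auto_move` and the helper bodies are hypothetical illustrations, not the merged code:

```python
# Minimal sketch of the auto-move idea, assuming hypothetical helper bodies;
# not the PR's exact implementation.
import torch
import torch_xla.core.xla_model as xm

def _get_input_arg_device(args):
    # Device of the first tensor argument, or None if there are no tensors.
    for arg in args:
        if isinstance(arg, torch.Tensor):
            return arg.device
    return None

def _args_on_cuda(args):
    device = _get_input_arg_device(args)
    return device is not None and device.type == "cuda"

def run_with_auto_move(compiled_graph, args):
    original_device = _get_input_arg_device(args)
    if _args_on_cuda(args):
        # Move CUDA inputs onto the XLA device before running the compiled graph.
        args = tuple(
            arg.to(xm.xla_device()) if isinstance(arg, torch.Tensor) else arg
            for arg in args)
    outputs = compiled_graph(*args)
    if original_device is not None and original_device.type == "cuda":
        # Move the results back to the caller's original CUDA device.
        outputs = tuple(out.to(original_device) for out in outputs)
    return outputs
```

The sketch only round-trips tensors through the XLA device; as the description notes, that assumption holds for inference but not yet for training with Autograd.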

@changm changm self-assigned this Feb 29, 2024
@changm changm requested a review from vanbasten23 February 29, 2024 00:24
Comment thread test/dynamo/test_dynamo.py Outdated
Comment thread test/dynamo/test_dynamo.py Outdated
Comment thread test/dynamo/test_dynamo.py Outdated
@vanbasten23 vanbasten23 requested a review from JackCaoG March 5, 2024 21:49
@changm changm requested a review from golechwierowicz March 6, 2024 15:11
Comment thread torch_xla/core/dynamo_bridge.py Outdated
@changm changm force-pushed the changm/automove branch from 476edaf to 2f7c0cc Compare March 7, 2024 14:51
@changm changm requested a review from vanbasten23 March 7, 2024 14:52
@JackCaoG (Collaborator) commented Mar 7, 2024

Is this an experimental PR, or do you want to merge it?

@changm changm force-pushed the changm/automove branch from f50e471 to 195978d Compare March 7, 2024 20:38
@changm (Collaborator, Author) commented Mar 7, 2024

> Is this an experimental PR, or do you want to merge it?

Ideally we would merge this, or is there a reason not to?

Comment thread test/dynamo/test_dynamo.py
Comment thread torch_xla/core/dynamo_bridge.py Outdated
Comment thread torch_xla/core/dynamo_bridge.py Outdated
@changm changm changed the title from "Automatically move non XLA Tensors to XLA Device and back to original device." to "Automatically move CUDA non XLA Tensors to XLA Device and back to CUDA device" Mar 11, 2024
Comment thread test/dynamo/test_dynamo.py
Comment thread torch_xla/core/dynamo_bridge.py
Comment thread torch_xla/core/dynamo_bridge.py Outdated
```python
nonlocal skip_checking_input_sharding_threashold

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = _args_on_cuda(args)
```
Collaborator commented:

_args_on_cuda will call _get_input_arg_device, which is redundant.

Collaborator (Author) replied:

I still think it's a little cleaner to do the redundant call, but I removed it here.
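Illustratively, the deduplicated pattern the thread converged on could look like this (a sketch, not the merged diff):

```python
# Sketch only: compute the device once and derive the CUDA check from it,
# instead of calling _get_input_arg_device a second time via _args_on_cuda.
original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = original_device is not None and original_device.type == "cuda"
```

This keeps a single device lookup while preserving the two names used in the snippet above.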

Comment thread torch_xla/core/dynamo_bridge.py
Comment thread torch_xla/core/dynamo_bridge.py
Comment thread torch_xla/core/dynamo_bridge.py
@changm changm requested review from JackCaoG and ysiraichi March 11, 2024 23:02
@changm changm merged commit d13ae1b into master Mar 13, 2024
@changm changm deleted the changm/automove branch March 13, 2024 16:41
@vanbasten23 (Collaborator) commented:

Sorry for being late. It's looking good!

yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024