Integrate dlpack to dynamo. #7173

Merged: vanbasten23 merged 14 commits into master from xiowei/integrate_dlpack_with_dynamo_fallback on Jun 14, 2024

Conversation

@vanbasten23 (Collaborator) commented on Jun 3, 2024

This PR integrates the DLPack API into dynamo so that when we move a tensor between CUDA and XLA, we no longer have to go through the CPU.

Test:
PJRT_DEVICE=CUDA python pytorch/xla/test/dynamo/test_dynamo.py -k test_simple_model_automoves_tensors
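
For context, a minimal sketch of the zero-copy conversion this enables (assuming the helpers in torch_xla/utils/dlpack.py; illustrative, not the PR's exact code):

import torch
from torch.utils import dlpack as torch_dlpack
import torch_xla.utils.dlpack as xla_dlpack

cuda_t = torch.randn(4, device="cuda:0")

# CUDA -> XLA: export the CUDA buffer as a DLPack capsule and import it
# into XLA, sharing the memory instead of staging through the CPU.
xla_t = xla_dlpack.from_dlpack(torch_dlpack.to_dlpack(cuda_t))

# XLA -> CUDA: the reverse direction, again without a CPU copy.
cuda_back = torch_dlpack.from_dlpack(xla_dlpack.to_dlpack(xla_t))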

@vanbasten23 vanbasten23 changed the title from "[DO NOT REVIEW YET] Integrate dlpack to dynamo." to "Integrate dlpack to dynamo." on Jun 6, 2024
@vanbasten23 vanbasten23 requested review from JackCaoG and ysiraichi June 6, 2024 17:42
@vanbasten23 vanbasten23 marked this pull request as ready for review June 6, 2024 17:43
@ysiraichi (Collaborator) left a comment:

Overall, LGTM.
I do think we need better testing for this use case, though. What do you think about running all dynamo tests with this flag set?

Comment thread: torch_xla/core/dynamo_bridge.py (outdated)
Comment thread: torch_xla/core/dynamo_bridge.py, on lines +712 to +714:
<< "The device currently being used : " << pjrt_device->DebugString()
<< " is different from the device where the buffer resides: "
Collaborator:

Cool, better error message!

@vanbasten23 (Collaborator, Author):

> Overall, LGTM. I do think we need better testing for this use case, though. What do you think about running all dynamo tests with this flag set?

Currently it works for inference but not for training. If we ran all dynamo tests, we would need to change all of them. OTOH, all this PR does is move CUDA tensors to the XLA device at the beginning of the dynamo bridge; the rest should remain the same. So do we really need to run all dynamo tests with the flag?

@vanbasten23 vanbasten23 requested a review from ysiraichi June 10, 2024 18:59
@vanbasten23 vanbasten23 force-pushed the xiowei/integrate_dlpack_with_dynamo_fallback branch from 5b47510 to 0f3218c on June 11, 2024 17:15
Comment thread: test/dynamo/test_dynamo.py (outdated):
xenv.ZERO_COPY_ENABLED: zero_copy_enabled,
})
x = torch.tensor(100.0).to(device="cuda:0")
y = torch.tensor(200.0).to(device="cuda:0")
Collaborator:

Shouldn't you check that the output is also on cuda:0?

Collaborator:

And somehow verify that the computation is run using dynamo, not the fallback.

vanbasten23 (Author):

> Shouldn't you check that the output is also on cuda:0?

The test already checks that:

self.assertTrue(res_xla_dynamo.device == original_device)

> And somehow verify that the computation is run using dynamo, not the fallback.

The test checks that tracing is skipped in subsequent runs:

# verify that tracing is skipped in following runs
met.clear_counters()
res_xla_dynamo_reused = fn_simple_dynamo(x, y)
self.assertNotIn('xla::add', met.counter_names())
Do you think that's enough?

Collaborator:

Since we are looking for fallbacks, what about using torch_xla._XLAC._get_executed_fallback_ops?

vanbasten23 (Author):

Sounds good. I added a check: self.assertEqual(torch_xla._XLAC._get_executed_fallback_ops(), [])
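
(For illustration, a hedged sketch of how such a check might be wrapped in a test helper; assert_no_fallback is a hypothetical name, not part of this PR:)

import torch_xla
import torch_xla.debug.metrics as met

def assert_no_fallback(test_case, fn, *args):
  # Hypothetical helper: clear counters, run the compiled function, then
  # assert that no op was executed through the fallback path.
  met.clear_all()
  result = fn(*args)
  test_case.assertEqual(torch_xla._XLAC._get_executed_fallback_ops(), [])
  return result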

# Have to move to CPU before moving it to target device.
moved_tensor = tensor.to(cpu_device)
moved_tensor = moved_tensor.to(target_device)
zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
Collaborator:

Is there a reason this has to be an env var? The biggest reason we use env vars is that we need to communicate something between the Python and C++ layers, or the master process needs to pass some information (like rank) to the child processes. In your case this is really just a config and should not be set as an env var.

Collaborator:

For your case I think you can just always use dlpack. Is there any reason we don't want to use dlpack to convert between XLA:GPU and CUDA?

vanbasten23 (Author):

The flag is only temporary. It's used to do an A/B test: how much performance we get when moving tensors through the CPU vs. how much we get by using dlpack.

The flag will be removed once the A/B testing is done.
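
(A rough sketch of the gated move being discussed; move_cuda_to_xla and the exact control flow are assumptions, not the PR's code:)

import torch
import torch_xla.core.xla_env_vars as xenv
import torch_xla.utils.utils as xu

def move_cuda_to_xla(tensor, target_device, cpu_device=torch.device("cpu")):
  # Temporary A/B switch: CPU round trip vs. DLPack zero copy.
  zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
  if zero_copy_enabled and tensor.device.type == "cuda":
    from torch.utils import dlpack as torch_dlpack
    import torch_xla.utils.dlpack as xla_dlpack
    # Share the CUDA buffer with XLA directly.
    return xla_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor))
  # Old path: have to move to CPU before moving to the target device.
  return tensor.to(cpu_device).to(target_device)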

@vanbasten23 vanbasten23 requested a review from JackCaoG June 12, 2024 15:55
@ysiraichi (Collaborator):

> Currently it works for inference but not for training. If we ran all dynamo tests, we would need to change all of them. OTOH, all this PR does is move CUDA tensors to the XLA device at the beginning of the dynamo bridge; the rest should remain the same. So do we really need to run all dynamo tests with the flag?

In this case, we could have a decorator or something like that for selecting a few tests to run with zero-copy. For example, there is DynamoInferenceBasicTest at test/dynamo/test_dynamo.py.

While this is a small change (quantity), it may end up changing the execution behavior in ways that we might not think of. I think it would make this PR more robust.
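
(A hypothetical sketch of such a decorator; the env var name is assumed to match the flag this PR adds, and the real tests may set it differently:)

import functools
import os

def with_and_without_zero_copy(test_fn):
  # Hypothetical decorator: run the wrapped test once per flag value so a
  # selected test exercises both the CPU path and the DLPack path.
  @functools.wraps(test_fn)
  def wrapper(self, *args, **kwargs):
    for flag in ("0", "1"):
      os.environ["ZERO_COPY_ENABLED"] = flag
      test_fn(self, *args, **kwargs)
  return wrapper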

@vanbasten23 (Collaborator, Author):

> Currently it works for inference but not for training. If we ran all dynamo tests, we would need to change all of them. OTOH, all this PR does is move CUDA tensors to the XLA device at the beginning of the dynamo bridge; the rest should remain the same. So do we really need to run all dynamo tests with the flag?

> In this case, we could have a decorator or something like that for selecting a few tests to run with zero-copy. For example, there is DynamoInferenceBasicTest at test/dynamo/test_dynamo.py.

Sure, I've added tests for DynamoInferenceBasicTest.

@ysiraichi (Collaborator) left a comment:

LGTM.

I left a few minor comments.
Thank you for taking the time to adapt the existing dynamo tests. They look great!

Comment thread: test/dynamo/test_dynamo.py (outdated), on lines +307 to +309:
# We need to make `dim` depend on `initialize_on_cuda` because the compilation cache
# does not clean itself between the parameterized tests.
dim = 5 + int(initialize_on_cuda)
Collaborator:

Do you mean (i) dynamo's cache or (ii) the XLA cache?

  • If it's (i), we can just reset it.
  • If it's (ii), can't we just reset it, too? If not, I would argue that we don't need to worry about it, since that's not what's being tested here.

What do you think?

vanbasten23 (Author):

It's (ii), and AFAIK there is no way to reset the XLA compilation cache.

> I would argue that we don't need to worry about it, since that's not what's being tested here.

Actually, the existing check

self.assertEqual(met.metric_data('CompileTime')[0], compile_count + 1)

tests the compilation cache. Without the change, the test would fail for the reason given in the comment. That's why I had to change it.
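
(For illustration, a hedged sketch of the pattern under discussion; the helper and assertions are assumptions, not the test's exact code:)

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def check_single_recompile(initialize_on_cuda):
  # Parameterized runs share one process, so the XLA compilation cache
  # survives between them; a distinct shape per run forces a fresh compile.
  dim = 5 + int(initialize_on_cuda)
  data = met.metric_data('CompileTime')
  compile_count = data[0] if data else 0
  t = torch.randn(dim, dim, device=xm.xla_device())
  res = t + t
  xm.mark_step()  # materialize the graph: triggers a compile and execution
  assert met.metric_data('CompileTime')[0] == compile_count + 1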

Collaborator:

Ah, I see. I guess that works. You could also run this check in only one of the runs (e.g., if initialize_on_cuda == True). But I think this is also OK.

Comment thread: torch_xla/core/dynamo_bridge.py
Comment thread: torch_xla/utils/dlpack.py
@vanbasten23 (Collaborator, Author):

Thanks for the review!

@vanbasten23 vanbasten23 merged commit c216d26 into master Jun 14, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Oct 11, 2024
@miladm miladm added the xla:gpu label Nov 22, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024
yitongh pushed a commit to AlibabaPAI/xla that referenced this pull request Dec 11, 2024