Transfer data directly to the device (#5772)
In cases where the source type does not match the output type, I believe we'll still have to "stage" the data in an intermediate tensor of the target type before the transfer.
I have everything working locally now. I'm separating the more tedious changes here into #5777. After this PR, we'll still have to make an intermediate copy if the input tensor type does not match the target type. I added a counter to capture this overhead, since it may have a performance impact. Getting around the copy is simple: just create the tensors such that the type matches what they will be on the device. So if you want a particular type on the device, create the input tensor with that type up front.
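The staging copy described above can be sketched as follows. Names here (`StageAsBf16`, the conversion counter) are illustrative stand-ins, not the actual PyTorch/XLA code; the sketch assumes an f32 input destined for a bf16 device tensor, so the data must be re-staged in the target type and a counter records the extra copy:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in for the metric counting transfers that required a staging copy.
static int64_t g_input_conversions = 0;

// Stage an f32 buffer as bf16 (keep the top 16 bits of each float). This is
// the kind of intermediate copy the PR avoids when the types already match.
std::vector<uint16_t> StageAsBf16(const std::vector<float>& src) {
  ++g_input_conversions;
  std::vector<uint16_t> out;
  out.reserve(src.size());
  for (float f : src) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    out.push_back(static_cast<uint16_t>(bits >> 16));
  }
  return out;
}
```

Creating the input as bf16 in the first place skips this conversion entirely, which is the workaround described above.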
The TPU CI is currently hanging.
* Remove `populate_fn` from `TensorSource`
* Make TensorSource an interface
* Re-enable pjrt_computation_client_test
* server -> device
* add comment
* fix outbound data metric
* formatting
* implement byte_strides in TensorSource
* more formatting
* remove extra deps
* add missing deps
* Revert "server -> device"

  This reverts commit 6384516.
* Revert "Simplify AtenSource"

  This reverts commit 4225deb.
@jonb377 and I were able to figure out where the deadlock is. The hang is caused by a GIL deadlock when we try to retrieve data from the device before a transfer finishes, and the transfer is the only thing keeping an `at::Tensor` alive. Here's what's happening:
Here's the relevant stack trace (collapsed under "Details").
I fixed a similar GIL deadlock bug about a year ago in #4504. In that case, the solution was to release the GIL during the transfer.
I wrapped the GIL release and data transfer into a new utility. I'm open to better names for it.
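The shape of the fix can be sketched with a stand-in lock. `FakeGil` and `RunWithoutGil` are hypothetical names for illustration, not the utility added in this PR; the point is simply that the lock is dropped for the duration of the blocking transfer and reacquired afterwards:

```cpp
#include <cassert>
#include <utility>

// Stand-in for the GIL; tracks whether it is currently held.
struct FakeGil {
  bool held = false;
  void lock() { held = true; }
  void unlock() { held = false; }
};

FakeGil g_gil;

// Run a blocking operation with the lock released, then reacquire it.
// If `fn` indirectly needs the lock (e.g. a callback that destroys Python
// objects), holding the lock across the call would deadlock.
template <typename Fn>
auto RunWithoutGil(Fn&& fn) {
  g_gil.unlock();
  auto result = std::forward<Fn>(fn)();
  g_gil.lock();
  return result;
}
```

A production version would use an RAII guard (as pybind11's `gil_scoped_release` does) so the lock is reacquired even if the call throws.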
```diff
  std::vector<DataPtr> TransferToServer(
-     absl::Span<const TensorSource> tensors) override;
+     absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
```
Should we require the caller to manage the ownership and lifetime of the `TensorSource` (i.e., take `const TensorSource*` instead), or is it necessary to ensure that the memory is held during the client ops and inside the client?
Good question. PJRT lets us tie the lifetime of an object to an operation by capturing it in a callback. `std::function`s have to be copyable, so `shared_ptr` is our best choice here. `TensorSource` itself may be expensive or impossible to copy.

The caller of `TransferToServer` will be much shorter-lived than the actual transfer, so ownership should pass down. We could tighten up the interface and consume a `unique_ptr`, since we only need copyability within the implementation of `TransferToServer`. What do you think?
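A minimal sketch of the lifetime argument, with hypothetical names (`TransferToServerAsync` stands in for the real transfer path): the completion callback copies the `shared_ptr`, so the source outlives the caller even though the caller returns immediately.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <thread>

// Stand-in for the real TensorSource interface.
struct TensorSource {
  explicit TensorSource(int n) : size(n) {}
  int size;
};

// Simulated async transfer: capturing the shared_ptr in the (copyable)
// std::function keeps the source alive until the callback has run.
std::thread TransferToServerAsync(std::shared_ptr<const TensorSource> src,
                                  std::function<void(int)> on_done) {
  std::function<void()> callback = [src, on_done] { on_done(src->size); };
  return std::thread(callback);
}
```

A lambda capturing a `unique_ptr` would fail to compile here: `std::function` requires a copyable callable, which is exactly why `shared_ptr` fits this interface.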
I will try to take a look today.
* Transfer data directly to the device (pytorch#5752)
* Remove `populate_fn` from `TensorSource`
* Make TensorSource an interface
* Re-enable pjrt_computation_client_test
* server -> device
* add comment
* fix outbound data metric
* formatting
* implement byte_strides in TensorSource
* more formatting
* remove extra deps
* add missing deps
* Revert "server -> device"

  This reverts commit 6384516.
* Use `at::Tensor`'s layout for byte strides
* Downcast at::Tensor if required
* formatting
* Simplify AtenSource
* fix build
* formatting
* fix typo that makes us ignore input type
* Revert "Simplify AtenSource"

  This reverts commit 4225deb.
* Skip hanging test
* fix gil deadlock
* formatting
Take two. See the notes from the original PR #5752.

New changes:

* Use the `at::Tensor`'s layout for byte strides instead of the strides in the `xla::Shape`.
* Downcast the `at::Tensor` if the target type differs from the actual type.
* Release the GIL in `XlaDataToTensors`, since `at::Tensor` destruction after `TransferToServer` can deadlock with `TransferFromServer`.