Distribute Literal->Tensor copies across thread pool #5825
Conversation
```diff
 std::vector<at::Tensor> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
```
**Reviewer:** What's the reason for this change?

**Reviewer:** It seems like we never really call XlaDataToTensors with different dest_element_type values. @jonb377, are you introducing a new use case? If not, can we keep it as a single scalar?

**Author (@jonb377):** The reason to make dest_element_type a span is actually the next change: I'm batching the local shard transfers for many tensors into a single XlaDataToTensors call. I probably should have kept this refactor with the upcoming change, but it makes that PR slightly smaller.
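For illustration, a hedged sketch of the batched call described above; `GetLocalShardHandles` and the per-shard dtypes are hypothetical stand-ins, not code from this PR:

```cpp
// Hypothetical caller (illustrative only): a single XlaDataToTensors call
// transfers every local shard, each with its own destination dtype.
std::vector<torch::lazy::BackendDataPtr> shard_handles =
    GetLocalShardHandles();  // hypothetical helper, not from the PR
std::vector<at::ScalarType> dest_types(shard_handles.size(),
                                       at::kFloat);  // placeholder dtype
// absl::Span binds implicitly to std::vector, so both arguments convert.
std::vector<at::Tensor> shards = XlaDataToTensors(shard_handles, dest_types);
```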
```cpp
std::vector<at::Tensor> tensors(literals.size());
absl::BlockingCounter counter(literals.size());
for (size_t i = 0; i < tensors.size(); ++i) {
  auto copy_fn = [&, i]() {
```
**Reviewer:** Can you capture the variables you need explicitly?

**Reviewer:** Actually, you need just about every variable in this scope since it's pretty narrow. I take that back.
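To make the capture pattern concrete, here is a minimal self-contained sketch of the fan-out/wait structure in the diff, with std::thread standing in for the project's thread pool (an assumption; the actual change schedules onto an existing pool):

```cpp
#include <cstddef>
#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

// Copies src[i] -> dst[i] in parallel and blocks until every copy is done.
// Each worker stands in for one Literal->Tensor copy from the PR.
void ParallelCopy(const std::vector<int>& src, std::vector<int>* dst) {
  absl::BlockingCounter counter(static_cast<int>(src.size()));
  std::vector<std::thread> workers;
  workers.reserve(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    // [&, i]: containers are captured by reference, but i by value so each
    // worker operates on its own index -- the same capture the diff uses.
    workers.emplace_back([&, i]() {
      (*dst)[i] = src[i];
      counter.DecrementCount();  // signal that this copy has finished
    });
  }
  counter.Wait();  // block the caller until all copies complete
  for (std::thread& t : workers) t.join();
}
```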
```diff
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
-    at::ScalarType dest_element_type) {
+    absl::Span<const at::ScalarType> dest_element_type) {
   std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
```
**Reviewer:** I wonder if we should just be returning Tensors here.

**Author (@jonb377):** I'm interested in making TransferFromServer return at::Tensor and cutting out the xla::Literal middleman, but that's still in the idea phase. I opted to keep this change smaller and just distribute the copy work over more cores.
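Sketching that idea (purely hypothetical, nothing in this PR commits to it), the transfer might eventually look like:

```cpp
// Idea phase only: transfer device data straight into at::Tensor,
// skipping the intermediate xla::Literal copy entirely.
std::vector<at::Tensor> TransferFromServer(
    absl::Span<const torch::lazy::BackendDataPtr> handles,
    absl::Span<const at::ScalarType> dest_element_types);
```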
**Author (@jonb377):** Thanks for the reviews @will-cromar and @JackCaoG!
* Distribute Literal->Tensor copies across thread pool
* Update for #5799
After an xla::Literal has been created in TransferFromServer, it must be copied into an at::Tensor. This copy incurs significant overhead (up to 3x the transfer overhead after #5824), because the copies still happen synchronously on a single thread.

This change dispatches the copies to a thread pool to parallelize the work. When checkpointing a 2B-parameter model, the copy overhead drops from ~5000ms to ~611ms.*

*Note: These benchmarks were taken prior to #5799 and used the old threading library.
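For anyone reproducing numbers like these, a minimal sketch of a wall-clock timing helper (the helper and its use are illustrative, not from the PR):

```cpp
#include <chrono>
#include <utility>

// Returns the wall-clock duration of fn() in milliseconds.
template <typename Fn>
double MillisecondsFor(Fn&& fn) {
  auto start = std::chrono::steady_clock::now();
  std::forward<Fn>(fn)();
  auto end = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(end - start).count();
}

// Usage (hypothetical): wrap the transfer-and-copy path under test.
// double ms = MillisecondsFor([&] { XlaDataToTensors(handles, dtypes); });
```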