Skip to content

Transfer data directly to the device#5752

Merged
will-cromar merged 12 commits intomasterfrom
wcromar/remove-populate-fn
Nov 2, 2023
Merged

Transfer data directly to the device#5752
will-cromar merged 12 commits intomasterfrom
wcromar/remove-populate-fn

Conversation

@will-cromar
Copy link
Copy Markdown
Collaborator

@will-cromar will-cromar commented Oct 31, 2023

  • Remove populate_fn entirely. We had to implement this weird indirection before to prevent our TensorFlow code from taking a dependency on PyTorch code.
  • Copy data directly from PyTorch Tensors into runtime buffers without having to stage them in an xla::Literal first.
  • Re-implement TensorSource as an interface that can directly wrap and own a data source.
    • Ownership makes it easier to implement async data transfers. Otherwise, we risk the source tensor being deleted before the transfer is done.
  • Rename TransferToServer since there's no Server. Reverted to make review easier. Will rebase commit into follow-up PR.
  • Re-enable and update PJRT client unit test, since it's the best unit test to cover this change.

@will-cromar will-cromar changed the title Remove populate_fn from TensorSource Transfer data directly to the device Nov 1, 2023
@will-cromar
Copy link
Copy Markdown
Collaborator Author

This very slightly improves performance on ResNet50 with fake data on TPU v4 by removing synchronous overhead. Key metrics after epoch 1:

Before:

Metric: DeviceLockWait
  TotalSamples: 2582
  Accumulator: 01s292ms557.601us
  ValueRate: 007ms301.380us / second
  Rate: 13.7858 / second
  Percentiles: 1%=000.880us; 5%=001.049us; 10%=001.240us; 20%=001.550us; 50%=005.870us; 80%=008.320us; 90%=002ms368.870us; 95%=005ms894.689us; 99%=006ms147.740us
...
Metric: TransferToServerTime
  TotalSamples: 1951
  Accumulator: 20s711ms722.341us
  ValueRate: 109ms777.976us / second
  Rate: 7.31885 / second
  Percentiles: 1%=065.640us; 5%=070.460us; 10%=073.750us; 20%=080.750us; 50%=094.750us; 80%=061ms422.564us; 90%=066ms168.024us; 95%=068ms557.643us; 99%=070ms677.393us

After:

Metric: DeviceLockWait
  TotalSamples: 2582
  Accumulator: 794ms078.846us
  ValueRate: 008ms619.516us / second
  Rate: 13.8462 / second
  Percentiles: 1%=001.020us; 5%=001.180us; 10%=001.310us; 20%=001.449us; 50%=005.920us; 80%=007.830us; 90%=002ms445.669us; 95%=005ms409.450us; 99%=006ms294.320us
...
Metric: TransferToDeviceTime
  TotalSamples: 1951
  Accumulator: 193ms746.667us
  ValueRate: 001ms032.600us / second
  Rate: 7.3311 / second
  Percentiles: 1%=047.810us; 5%=051.320us; 10%=057.360us; 20%=061.490us; 50%=074.340us; 80%=107.620us; 90%=123.420us; 95%=142.660us; 99%=197.790us

Note the drop from >20 seconds to ~200 ms in TransferToServerTime. In practice, this shaves off about a half second of waiting since ResNet50 is not actually transfer-bound.

@will-cromar will-cromar requested a review from JackCaoG November 1, 2023 20:45
@will-cromar will-cromar marked this pull request as ready for review November 1, 2023 20:45
Copy link
Copy Markdown
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! Thanks @will-cromar

@will-cromar will-cromar merged commit 5ca36cf into master Nov 2, 2023
will-cromar added a commit that referenced this pull request Nov 2, 2023
will-cromar added a commit that referenced this pull request Nov 2, 2023
will-cromar added a commit that referenced this pull request Nov 6, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
will-cromar added a commit that referenced this pull request Nov 8, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
will-cromar added a commit that referenced this pull request Nov 9, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
@yeounoh
Copy link
Copy Markdown
Contributor

yeounoh commented Nov 9, 2023

Just learning about this change, this is great feat -- thanks @will-cromar !

will-cromar added a commit that referenced this pull request Nov 13, 2023
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
* Transfer data directly to the device (pytorch#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
zpcore pushed a commit that referenced this pull request Nov 21, 2023
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
lsy323 pushed a commit to lsy323/xla that referenced this pull request Nov 28, 2023
* Transfer data directly to the device (pytorch#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
* Transfer data directly to the device (pytorch#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
golechwierowicz pushed a commit that referenced this pull request Jan 12, 2024
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants