Skip to content

Enable PJRT C API with libtpu-nightly#4400

Merged
will-cromar merged 6 commits intomasterfrom
wcromar/pjrt-c-api-fixes
Jan 18, 2023
Merged

Enable PJRT C API with libtpu-nightly#4400
will-cromar merged 6 commits intomasterfrom
wcromar/pjrt-c-api-fixes

Conversation

@will-cromar
Copy link
Copy Markdown
Collaborator

@will-cromar will-cromar commented Dec 28, 2022

As of our last TF update, the libtpu-nightly package contains the new PJRT C API.

  • Use GetHloModules now that it's implemented
  • Switch ExecuteSharded to use on_device_shape to match ExecuteReplicated
  • Use xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes now that it's implemented
  • Temporary hack: copy TimedSection pointer to output buffers' OnReady callbacks. Futures with ExecuteSharded are already implemented and will be available in the next libtpu update (thanks @jyingl3!)
  • Add a TPU test for the ExecuteTime metric to make sure I'm using the callback correctly.
  • Explicitly load TPU PJRT plugin to address a TODO in TensorFlow

Tested manually with test_experimental_pjrt_tpu.py and test_train_mp_imagenet.py on TPU v4.

@will-cromar will-cromar force-pushed the wcromar/pjrt-c-api-fixes branch from 394fa65 to c9bcb09 Compare January 12, 2023 21:21
@will-cromar will-cromar marked this pull request as ready for review January 13, 2023 21:09
@will-cromar will-cromar requested a review from JackCaoG January 13, 2023 21:09
return execute_time_ns

def test_execute_time_metric(self):
results = pjrt._run_multiprocess(self._execute_time_metric)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this test the one added by Goran?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Goran added a similar test for CPU here:

xla/test/test_metrics.py

Lines 166 to 186 in 9c01c3c

@unittest.skipIf(
xm.get_xla_supported_devices("GPU") or
xm.get_xla_supported_devices("TPU"), f"This test only works on CPU.")
def test_execute_time_metric(self):
# Initialize the client before starting the timer.
xm.xla_device()
begin = time.perf_counter_ns()
value = torch.randn(
10000, 10000, device=xm.xla_device()) * torch.randn(
10000, 10000, device=xm.xla_device())
value_mean = value.mean()
xm.mark_step()
cpu_value = value_mean.cpu()
wall_time_ns = time.perf_counter_ns() - begin
self.assertIn("ExecuteTime", met.metric_names())
execute_time_ns = met.metric_data('ExecuteTime')[1]
# Execution time should be the bulk of the wall time.
# Ensures that the metric does not measure the execution
# of `ExecuteComputation`, but the actual async time.
self.assertGreater(execute_time_ns, .5 * wall_time_ns)

Making that test work for TPU is tricky because the overhead is way higher on TPU, so the execute time is a small fraction of the overall wall time. I did check that this new test does fail if I revert our fix to the execute time metric.

returned_future->OnReady([timed](Status unused) mutable { timed.reset(); });
// TODO(wcromar): Uncomment this when we update past TF commit
// b8f59020ea0e9e6fba0e9c5e7be88271703eaf9e
// returned_future->OnReady([timed](Status unused) mutable { timed.reset();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is returned_future->OnReady broken with current pin?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The fix should be included in our next pin update and we can revert this change.

@will-cromar will-cromar force-pushed the wcromar/pjrt-c-api-fixes branch from e53f499 to fc72c4f Compare January 17, 2023 22:12
@will-cromar will-cromar merged commit c1a9879 into master Jan 18, 2023
ManfeiBai pushed a commit that referenced this pull request Jan 19, 2023
* Update PjRtComputationClient for latest PJRT C API

* Load PJRT plugin

* Add TPU version of metric timing test

* Add commit hash

* Cleanup

* Formatting
ManfeiBai pushed a commit that referenced this pull request Jan 19, 2023
* Update PjRtComputationClient for latest PJRT C API

* Load PJRT plugin

* Add TPU version of metric timing test

* Add commit hash

* Cleanup

* Formatting
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants