[SPMD] Support manual sharding #6915
Conversation
jonb377
left a comment
Interesting! I've been wondering what manual sharding is for.
// The returned tensors will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// the `tile_assignment`; MANUAL sharding results in shards where only the
// first device holds the full data; the returned tensor shards vector is
only the first device holds the full data
Is this by definition of manual sharding?
This is not by definition, but by our implementation choice. A more proper example would be a list of tensors (DTensor), where each tensor is an individual full shard.
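To illustrate the two representations with plain torch tensors (a hypothetical sketch, not actual torch_xla shard objects):

import torch

full = torch.randn(4, 4)
num_devices = 4

# Implementation choice in this PR: only the first slot holds the full data.
shards_in_this_pr = [full]

# DTensor-style alternative: every device holds an individual full shard.
shards_dtensor_style = [full.clone() for _ in range(num_devices)]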
Per our offline discussion, abstain from manual sharding on input data.
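For the record, a sketch of what rejecting manual sharding on input data could look like from Python (the import alias follows the test file; the exception type and message are assumptions, not taken from this PR):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

x = torch.randn(3, 2).to(xm.xla_device())  # plain input device data
try:
    xt = xs._mark_manual_sharding(x)
except RuntimeError as e:  # assumed exception type
    print('manual sharding rejected for input data:', e)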
std::vector<at::Tensor> cpu_shards =
    XlaDataToTensors(WrapXlaData(shard_handles), element_types);
result.reserve(cpu_shards.size() / shards_per_tensor);
for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
Calling XlaDataToTensors on each tensor individually will slow down d2h transfers for async checkpointing, since PjRt won't be able to fully utilize transfer parallelization.
Do we expect manually-sharded tensors to contain actual device data generally, or will they usually be IR? If just IR, maybe we can add an assertion to prevent access here.
I'd rather keep it functional for both cases -- shouldn't it be asynchronous anyway, not blocking the actual training run?
This is interesting. I was not aware of this performance optimization...
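A toy, torch_xla-free sketch of the transfer-parallelization point above: issuing copies one at a time serializes them, while handing the runtime the whole list lets the copies overlap (mimicked here with sleeps and a thread pool):

import time
from concurrent.futures import ThreadPoolExecutor

def d2h_copy(handle):
    time.sleep(0.05)  # stand-in for one device-to-host shard copy
    return handle

handles = list(range(8))

start = time.time()
serial = [d2h_copy(h) for h in handles]  # per-shard calls serialize
print(f'per-shard loop: {time.time() - start:.2f}s')

start = time.time()
with ThreadPoolExecutor(max_workers=len(handles)) as pool:
    batched = list(pool.map(d2h_copy, handles))  # overlapped, like one batched call
print(f'batched: {time.time() - start:.2f}s')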
} else if ((sharding.type() == xla::OpSharding::MANUAL)) {
  // Just put the full tensor on the first device.
  shards[0] = tensor;
  shards.resize(1);
How does this work for a computation, since we need to feed each device some input data?
e.g. based on your unit test, what happens if we run:
x = torch.randn(3, 2)
xx = x.to(xm.xla_device()) # xx is device data
xt = xs._mark_manual_sharding(xx)
ones = torch.ones(3, 2).to(xm.xla_device()) # ones is replicated to all devices
print(xt + ones)  # What will happen here?
XLA should assume that xt is sharded manually, so it is expected to be replicated as well. The purpose of MANUAL is to support custom kernels and to prevent XLA from overriding the manual sharding.
Good question. I would expect it to behave as a single device. Let me double-check as well.
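One way to double-check, extending the snippet from the question (hypothetical; assumes torch_xla is imported and that the result of the add exposes global_tensor like XLAShardedTensor does):

out = xt + ones
t = out.global_tensor if hasattr(out, 'global_tensor') else out
hlo = torch_xla._XLAC._get_xla_tensors_hlo([t])
print(hlo)  # check whether the add carries sharding={manual} or is replicated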
xt = xs._mark_manual_sharding(xx)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
self.assertIn('parameter(0), sharding={manual}', hlo)

self.assertEqual(id(mesh), id(expected_mesh))

def test__mark_manual_sharding(self):
nit: even though it's testing the _-prefixed API, let's keep the test name as test_mark_manual_sharding
Here is the new TPU CI run: https://github.com/pytorch/xla/actions/runs/8652176761

TPU test here: https://github.com/pytorch/xla/actions/runs/8654305716

All tests passed. I'm going to merge it. Let me know if I need to follow up on anything.
Summary:
This pull request makes SPMD support the manual sharding type via a new private API, _mark_manual_sharding. I don't expect users to need to call this function explicitly.
Besides adding support for the sharding annotation, we also need to define the behavior of the data shards. For data, the current behavior is to error out.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test__mark_manual_sharding
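A minimal end-to-end sketch assembled from the snippets in this thread (import paths and the use_spmd() call are assumptions; per the summary, marking plain input data errors out, so the input goes through an op first to make it IR):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
x = torch.randn(3, 2).to(xm.xla_device())
y = x + 0  # make the tensor IR rather than plain device data (assumption)
xt = xs._mark_manual_sharding(y)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
assert 'sharding={manual}' in hlo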