Skip to content

Commit dfda239

Browse files
fegin authored and pytorchmergebot committed
[DTensor] Raise a RuntimeError when checkpointing APIs are used with Partial placement (#163941)
A DTensor that contains a partial placement shouldn't be checkpointed (DCP.save) -- the result is not correct and DCP doesn't know how to handle it. There are several APIs that are used only by checkpointing, e.g., `__create_write_items__`. These APIs should raise an exception if the DTensor, `self`, has a Partial placement. Ideally, we want to add the following test: ``` with self.assertRaisesRegex( RuntimeError, "Any checkpointing related operations are not supported for" ): dcp.save({"dtensor": dtensor}, checkpoint_id=tempfile.gettempdir()) ``` While we do see the RuntimeError raised, it is raised in another thread, because DTensor checkpoint APIs are called by DCP in a separate thread, which assertRaisesRegex cannot capture. Pull Request resolved: #163941 Approved by: https://github.com/tianyu-l
1 parent 991e3d0 commit dfda239

2 files changed

Lines changed: 47 additions & 0 deletions

File tree

test/distributed/tensor/test_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates
22
# Owner(s): ["oncall: distributed"]
33

4+
import tempfile
5+
46
import torch
7+
import torch.distributed.checkpoint as dcp
58
import torch.nn as nn
69
from torch.distributed.tensor import (
710
DeviceMesh,
811
distribute_module,
912
distribute_tensor,
1013
DTensor,
14+
Partial,
1115
Replicate,
1216
Shard,
1317
)
@@ -356,6 +360,33 @@ def shard_fn(name, module, device_mesh):
356360
self.assertFalse(param.is_meta)
357361
self.assertTrue(param.device.type == device_mesh.device_type)
358362

363+
@with_comms
def test_checkpoint_apis_check_partial_placement(self):
    """Checkpoint-related DTensor hooks must reject Partial placements.

    A DTensor with a ``Partial()`` placement holds unreduced per-rank
    values, so each of the dunder APIs used by DCP (``__create_write_items__``,
    ``__create_chunk_list__``, ``__get_tensor_shard__``) is expected to raise
    a ``ValueError``, and a full ``dcp.save`` is expected to surface a
    ``CheckpointException`` wrapping that failure.
    """
    device_mesh = self.build_device_mesh()
    tensor = torch.randn(5, 5, device=self.device_type)
    # Partial placement: local tensors are pending a cross-rank reduction.
    dtensor = DTensor.from_local(tensor, device_mesh, [Partial()])
    with self.assertRaisesRegex(
        ValueError, "Any checkpointing related operations are not supported for"
    ):
        dtensor.__create_write_items__("fqn", None)

    with self.assertRaisesRegex(
        ValueError, "Any checkpointing related operations are not supported for"
    ):
        dtensor.__create_chunk_list__()

    with self.assertRaisesRegex(
        ValueError, "Any checkpointing related operations are not supported for"
    ):
        dtensor.__get_tensor_shard__(0)

    # Ideally we should not allow checkpointing related operations for DTensor
    # NOTE(review): the ValueError is raised in DCP's worker thread, so here we
    # can only assert on the CheckpointException that DCP re-raises — presumably
    # why the message check below matches the wrapped error text.
    with self.assertRaisesRegex(
        dcp.api.CheckpointException,
        "Any checkpointing related operations are not supported for",
    ):
        dcp.save({"fqn": dtensor}, checkpoint_id=tempfile.mkdtemp())
389+
359390

360391
if __name__ == "__main__":
361392
run_tests()

torch/distributed/tensor/_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,21 @@ def placements(self) -> tuple[Placement, ...]:
592592
"""
593593
return self._spec.placements
594594

595+
def _raise_if_contains_partial_placements(self) -> None:
596+
"""
597+
Raise an error if the DTensor contains partial placements.
598+
"""
599+
for placement in self._spec.placements:
600+
if not isinstance(placement, Partial):
601+
continue
602+
603+
raise ValueError(
604+
"Any checkpointing related operations are not supported for "
605+
"DTensor with partial placements!"
606+
)
607+
595608
def __create_write_items__(self, fqn: str, object: Any):
609+
self._raise_if_contains_partial_placements()
596610
from torch.distributed.checkpoint.planner_helpers import (
597611
_create_write_items_for_dtensor,
598612
)
@@ -615,6 +629,7 @@ def __create_chunk_list__(self):
615629
Returns:
616630
A List[:class:`ChunkStorageMetadata`] object that represents the shard size/offset on the current rank.
617631
"""
632+
self._raise_if_contains_partial_placements()
618633
from torch.distributed.checkpoint.planner_helpers import (
619634
_create_chunk_from_dtensor,
620635
)
@@ -627,6 +642,7 @@ def __create_chunk_list__(self):
627642
raise RuntimeError("Unsupported tensor type!")
628643

629644
def __get_tensor_shard__(self, index):
645+
self._raise_if_contains_partial_placements()
630646
if hasattr(self._local_tensor, "__get_tensor_shard__"):
631647
return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
632648
elif isinstance(self._local_tensor, torch.Tensor):

0 commit comments

Comments
 (0)