Skip to content

Commit 5f287df

Browse files
pytorchmergebot
authored and committed
Add type information for FakeProcessGroup (#133211)
Pull Request resolved: #133211 Approved by: https://github.com/Skylion007
1 parent e557444 commit 5f287df

3 files changed

Lines changed: 14 additions & 4 deletions

File tree

torch/_C/_distributed_c10d.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,9 @@ class ProcessGroup:
520520
@property
521521
def group_desc(self) -> str: ...
522522

523+
class FakeProcessGroup(Backend):
524+
def __init__(self, rank: int, world_size: int) -> None: ...
525+
523526
class ProcessGroupGloo(Backend):
524527
class Device: ...
525528

torch/csrc/distributed/c10d/init.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3151,9 +3151,13 @@ such as `dist.all_reduce(tensor, async_op=True)`.
31513151
auto fakeProcessGroup =
31523152
intrusive_ptr_no_gil_destructor_class_<::c10d::FakeProcessGroup>(
31533153
module, "FakeProcessGroup", backend)
3154-
.def(py::init([](int rank, int size) {
3155-
return c10::make_intrusive<::c10d::FakeProcessGroup>(rank, size);
3156-
}));
3154+
.def(
3155+
py::init([](int rank, int size) {
3156+
return c10::make_intrusive<::c10d::FakeProcessGroup>(
3157+
rank, size);
3158+
}),
3159+
py::arg("rank"),
3160+
py::arg("world_size"));
31573161

31583162
py::class_<c10::DDPLoggingData>(module, "DDPLoggingData")
31593163
.def(py::init<>())

torch/distributed/fsdp/_flat_param.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,7 +1439,10 @@ def _all_gather_flat_param(
14391439
# HACK this should be handled by C10D
14401440
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
14411441
tensor_list = list(
1442-
torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
1442+
torch.chunk(
1443+
padded_unsharded_flat_param,
1444+
dist.get_world_size(pg), # type: ignore[arg-type]
1445+
)
14431446
)
14441447
dist.all_gather(tensor_list, sharded_flat_param, group=pg)
14451448
else:

0 commit comments

Comments (0)