Skip to content

Commit 4301818

Browse files
kwen2501 authored and pytorchmergebot committed
[SymmMem] Back symm_mem.empty() with implicit pool (pytorch#172292)
Resolves pytorch#172050 Two motivations: - Give better UX and perf to users who explicitly use `symm_mem.empty()`. - Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it. The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously. Pull Request resolved: pytorch#172292 Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba ghstack dependencies: pytorch#172163
1 parent 4c2c83b commit 4301818

1 file changed

Lines changed: 33 additions & 6 deletions

File tree

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,26 @@ def _all_to_all_vdev_2d_offset_meta(
18791879
from torch.types import _device, _dtype, _int
18801880

18811881

1882+
_use_implicit_mempool: bool | None = None # type: ignore[assignment]
1883+
1884+
1885+
def _should_use_implicit_mempool() -> bool:
1886+
r"""
1887+
Check if the implicit memory pool should be used for symmetric memory allocations.
1888+
1889+
Returns:
1890+
bool: True if the implicit memory pool should be used, False otherwise.
1891+
1892+
By default, use implicit memory pool for `symm_mem.empty`. Users can
1893+
disable this by setting the environment variable `TORCH_SYMMMEM_IMPLICIT_POOL` to `0`.
1894+
"""
1895+
global _use_implicit_mempool
1896+
if _use_implicit_mempool is None:
1897+
_use_implicit_mempool = os.getenv("TORCH_SYMMMEM_IMPLICIT_POOL", "1") == "1"
1898+
1899+
return _use_implicit_mempool
1900+
1901+
18821902
@overload
18831903
def empty(
18841904
*size: _int, dtype: _dtype | None = None, device: _device | None = None
@@ -1927,13 +1947,20 @@ def empty( # type: ignore[misc]
19271947

19281948
if device is None:
19291949
device = torch.get_default_device()
1950+
else:
1951+
device = torch.device(device)
19301952

1931-
return _SymmetricMemory.empty_strided_p2p(
1932-
size=size,
1933-
stride=torch._prims_common.make_contiguous_strides_for(size),
1934-
dtype=dtype,
1935-
device=torch.device(device),
1936-
)
1953+
stride = torch._prims_common.make_contiguous_strides_for(size)
1954+
1955+
if _should_use_implicit_mempool() and device.type == "cuda":
1956+
# Allocate tensor from an implicit memory pool
1957+
mempool = get_mem_pool(device)
1958+
# TODO: this path can be made device-agnostic if `use_mem_pool` is
1959+
# elevated from torch.cuda to torch accelerator.
1960+
with torch.cuda.use_mem_pool(mempool):
1961+
return _SymmetricMemory.empty_strided_p2p(size, stride, dtype, device)
1962+
else:
1963+
return _SymmetricMemory.empty_strided_p2p(size, stride, dtype, device)
19371964

19381965

19391966
def rendezvous(

0 commit comments

Comments
 (0)