
[Inductor] introduce comm buffer planning #138519

Open
yifuwang wants to merge 5 commits into gh/yifuwang/152/base from gh/yifuwang/152/head

Conversation

@yifuwang
Collaborator

@yifuwang yifuwang commented Oct 21, 2024

Stack from ghstack (oldest at bottom):

NOTE [comm buffer planning]

This file contains the memory planning logic for comm buffers. Compared to
regular buffer planning, Inductor leverages the allocator's "persistent
allocation" capability to meet the stringent requirements for registered
buffer communication (see NOTE [lowering-time collective optimization] for
details regarding the requirements).

Comm buffer planning is a collaborative process between Inductor and the
allocator. Inductor is responsible for planning comm buffers within a graph.
For each (comm_buffer_type, group_name) pair, Inductor uses a single,
persistently allocated pool to fulfill all comm buffer allocations. The
allocator manages memory reuse across subgroups.
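The per-(comm_buffer_type, group_name) pooling described above can be sketched as a tiny offset planner. This is a hypothetical illustration, not Inductor's actual implementation; the real planner also accounts for buffer lifetimes, and the alignment value here is an assumption:

```python
from collections import defaultdict

def plan_comm_buffers(allocs, align=64):
    """Assign each comm buffer an offset in a per-(comm_buffer_type,
    group_name) pool and return the required pool sizes.

    `allocs` is a list of (name, comm_buffer_type, group_name, nbytes)
    tuples. This naive planner simply stacks buffers back to back; unlike
    the real planner, it ignores live ranges entirely.
    """
    offsets = {}
    pool_sizes = defaultdict(int)
    for name, buf_type, group, nbytes in allocs:
        key = (buf_type, group)
        # Round the running offset up to the alignment boundary.
        offset = -(-pool_sizes[key] // align) * align
        offsets[name] = (key, offset)
        pool_sizes[key] = offset + nbytes
    return offsets, dict(pool_sizes)
```

Each distinct key gets its own pool, so buffers for different groups or buffer types never share memory, mirroring the isolation described above.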

In practice, comm buffer planning differs from regular buffer planning in the
following ways:

- Comm buffers can't use memory from regular pools, and non-comm buffers
shouldn't use memory from comm buffer pools [1]. This means that (1) comm
buffers are planned in isolation from regular buffers and (2) comm buffers
don't participate in in-place reuse (this simplifies the logic).
- To allow for memory reuse for persistent allocations across subgraphs,
Inductor needs to "free" the pool before exiting each subgraph. This means
that comm buffers cannot be graph outputs (this simplifies the logic).
- Comm buffer pools are allocated with dedicated allocators.
- Comm buffers for different (comm_buffer_type, group_name) pairs are planned
separately in an isolated fashion.

For comm buffer planning, we reuse most of the fundamental logic from regular
buffer planning. To accommodate the above differences, we use the
`CommBufferLine` hierarchy, which resembles the `MemoryPlanningLine`
hierarchy, to represent comm buffer allocations at codegen time. We use the
`CommBufferPlanner` to carry out comm buffer planning in an isolated fashion.
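As a rough illustration of the `CommBufferLine` idea, here is a minimal hypothetical hierarchy. Only `CommBufferLine` is a name taken from this PR; the subclass names, fields, and `codegen` signatures are invented for the sketch:

```python
from dataclasses import dataclass

@dataclass
class CommBufferLine:
    """Base class for comm buffer codegen lines, playing the role that
    MemoryPlanningLine plays for regular buffers (illustrative sketch)."""
    name: str

@dataclass
class CommBufferAllocateLine(CommBufferLine):
    comm_buffer_type: str
    group_name: str
    nbytes: int

    def codegen(self, pool_var: str, offset: int) -> str:
        # Emit an allocation carved out of the persistent pool at a
        # planned offset, as in the generated code shown below.
        return f"{self.name} = alloc_from_pool({pool_var}, {offset}, ...)"

@dataclass
class CommBufferFreeLine(CommBufferLine):
    def codegen(self, pool_var: str) -> str:
        # The pool must be "freed" before graph exit so the allocator
        # can reuse the persistent allocation across subgraphs.
        return f"del {pool_var}"
```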

Ideally, the comm buffer planning logic could be further consolidated with
the existing buffer planning logic. For the time being, since the two code
paths differ in maturity, we prefer isolation at the cost of some divergence.

[1] Allowing non-comm buffers to use memory from comm buffer pools may
unnecessarily increase the size of persistently allocated memory. In the
future, we can optimize memory usage by performing comm buffer planning
first, then letting regular buffer planning leverage the free live ranges
from the comm buffer pools.

Example Graph

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        symm_mem_pool_0 = empty_strided_p2p((192, ), (1, ), torch.uint8, torch.device("cuda:0"), group_name="0", alloc_id=13423167936938864713)
        buf0 = alloc_from_pool(symm_mem_pool_0, 128, torch.float32, (4, 4), (4, 1))
        buf3 = alloc_from_pool(symm_mem_pool_0, 64, torch.float32, (4, 4), (4, 1))
        buf6 = alloc_from_pool(symm_mem_pool_0, 0, torch.float32, (4, 4), (4, 1))
        # Topologically Sorted Source Nodes: [a, b, c], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, buf3, buf6, 16, grid=grid(16), stream=stream0)
        del arg0_1
        # Topologically Sorted Source Nodes: [a, a_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf1 = torch.ops.symm_mem.one_shot_all_reduce.default(buf0, 'sum', '0')
        buf2 = buf1
        del buf1
        # Topologically Sorted Source Nodes: [b, b_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf4 = torch.ops.symm_mem.one_shot_all_reduce.default(buf3, 'sum', '0')
        buf5 = buf4
        del buf4
        # Topologically Sorted Source Nodes: [c, c_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf7 = torch.ops.symm_mem.one_shot_all_reduce.default(buf6, 'sum', '0')
        buf8 = buf7
        del buf7
        buf9 = alloc_from_pool(symm_mem_pool_0, 0, torch.float32, (4, 4), (4, 1))
        # Topologically Sorted Source Nodes: [add_3, d], Original ATen: [aten.add]
        triton_poi_fused_add_1.run(buf9, buf5, buf8, 16, grid=grid(16), stream=stream0)
        del buf2
        del buf5
        del buf8
        # Topologically Sorted Source Nodes: [add_3, d, d_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf10 = torch.ops.symm_mem.one_shot_all_reduce.default(buf9, 'sum', '0')
        del symm_mem_pool_0, buf0, buf3, buf6, buf9
        buf11 = buf10
        del buf10
    return (buf11, )

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Oct 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138519

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 49456c3 with merge base 5ea6777:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Oct 21, 2024
yifuwang pushed a commit that referenced this pull request Oct 21, 2024
ghstack-source-id: 90b4868
Pull Request resolved: #138519
@yifuwang yifuwang requested a review from Chillee October 21, 2024 23:07
@yifuwang yifuwang added the topic: not user facing topic category label Oct 21, 2024
Collaborator

@Chillee Chillee left a comment


Overall, looks cool!

Haven't taken a close look yet, but how does this handle buffer allocation across layers? It just doesn't share the memory?

yifuwang pushed a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: ec6863c
Pull Request resolved: #138519
yifuwang pushed a commit that referenced this pull request Oct 24, 2024
ghstack-source-id: c9f4730
Pull Request resolved: #138519
yifuwang pushed a commit that referenced this pull request Oct 25, 2024
ghstack-source-id: 1cbfe12
Pull Request resolved: #138519
#
# Comm buffer planning is a collaborative process between Inductor and the
# allocator. Inductor is responsible for planning comm buffers within a graph.
# For each (comm_buffer_type, group_name) pair, Inductor uses a single,
Contributor

Can you define the (comm_buffer_type, group_name) pair (or explain what the different types are for)? I assume group_name refers to a process group? Also, why do different PGs need to have different allocations? Is it because we have to tie the registration to a particular ncclComm?

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Feb 17, 2025
@kwen2501 kwen2501 added no-stale and removed Stale labels Mar 6, 2025
pytorchmergebot pushed a commit that referenced this pull request Jan 16, 2026
See #162859. This PR adds initial support for symmetric memory buffers (comm buffers) in `torch.compile` by realizing comm buffers during Inductor lowering and enabling conservative reuse through the existing `memory_plan_reuse` infrastructure.

* **Comm buffer realization**: Each `torch.ops.symm_mem` operation is lowered to allocate a comm buffer via `empty_strided_p2p`.
* **Layout support**: Relaxes layout restrictions so both `FixedLayout` and `FlexibleLayout` buffers can be realized as comm buffers.
* **Comm buffer reuse**: Comm buffers are reused only when their lifetimes do not overlap and when they share an identical reuse key `(device, dtype, size, comm_buffer_type, group_name)`. To prevent mixing communication buffers with regular CUDA buffers, the memory planner maintains a dedicated comm-buffer reuse pool and routes allocations via a `comm_buffer` flag on existing planning lines, eliminating the need for separate comm-buffer-specific line classes.
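The reuse-key rule above can be sketched as follows. The function names and signatures are hypothetical; the real planner operates on Inductor's planning lines rather than plain tuples:

```python
from collections import defaultdict

def make_free_pool():
    # Dead comm buffers awaiting reuse, bucketed by their full reuse key.
    return defaultdict(list)

def release(free_pool, buf, device, dtype, size, comm_buffer_type, group_name):
    """Mark a comm buffer dead so a later allocation with an identical
    reuse key may take it over."""
    free_pool[(device, dtype, size, comm_buffer_type, group_name)].append(buf)

def allocate(free_pool, device, dtype, size, comm_buffer_type, group_name):
    """Reuse a dead comm buffer only on an exact reuse-key match; this
    keeps comm buffers segregated from regular CUDA buffers."""
    key = (device, dtype, size, comm_buffer_type, group_name)
    if free_pool[key]:
        return free_pool[key].pop(), True   # lifetimes ended: safe to reuse
    return f"new_comm_buf{key}", False      # caller allocates a fresh buffer
```

Because the key includes `comm_buffer_type` and `group_name`, a buffer freed in one group can never be handed to an operation on a different group.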

More general memory planning (like what’s proposed in #138519) can be a follow-up.

Pull Request resolved: #171909
Approved by: https://github.com/kwen2501, https://github.com/eellison
5 participants