# [Inductor] introduce comm buffer planning #138519

yifuwang wants to merge 5 commits into `gh/yifuwang/152/base`
## Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/138519. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 49456c3 with merge base 5ea6777. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
**Chillee** left a comment:
Overall, looks cool!
Haven't taken a close look yet, but how does this handle buffer allocation across layers? It just doesn't share the memory?
```
NOTE [comm buffer planning]
This file contains the memory planning logic for comm buffers. Compared to
regular buffer planning, Inductor leverages the allocator's "persistent
allocation" capability to meet the stringent requirements for registered
buffer communication (see NOTE [lowering-time collective optimization] for
details regarding the requirements).
Comm buffer planning is a collaborative process between Inductor and the
allocator. Inductor is responsible for planning comm buffers within a graph.
For each (comm_buffer_type, group_name) pair, Inductor uses a single,
persistently allocated pool to fulfill all comm buffer allocations. The
allocator manages memory reuse across subgroups.
In practice, comm buffer planning differs from regular buffer planning in the
following ways:
- Comm buffers can't use memory from regular pools, and non-comm buffers
shouldn't use memory from comm buffer pools [1]. This means that (1) comm
buffers are planned in isolation from regular buffers and (2) comm buffers
don't participate in in-place reuse (this simplifies the logic).
- To allow for memory reuse for persistent allocations across subgraphs,
Inductor needs to "free" the pool before exiting each subgraph. This means
that comm buffers cannot be graph outputs (this simplifies the logic).
- Comm buffer pools are allocated with dedicated allocators.
- Comm buffers for different (comm_buffer_type, group_name) pairs are planned
separately in an isolated fashion.
For comm buffer planning, we reuse most of the fundamental logic from regular
buffer planning. To accommodate the above differences, we use the
`CommBufferLine` hierarchy, which resembles the `MemoryPlanningLine`
hierarchy, to represent comm buffer allocations at codegen time. We use the
`CommBufferPlanner` to carry out comm buffer planning in an isolated fashion.
Ideally, the comm buffer planning logic could be further consolidated with
the existing buffer planning logic. For the time being, since the two code
paths differ in maturity, we prefer isolation at the cost of some divergence.
[1] Allowing non-comm buffers to use memory from comm buffer pools may
unnecessarily increase the size of persistently allocated memory. In the
future, we can optimize memory usage by performing comm buffer planning
first, then letting regular buffer planning leverage the free live ranges
from the comm buffer pools.
```
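The per-`(comm_buffer_type, group_name)` pooling described in the NOTE can be sketched in a few lines. This is a hypothetical illustration, not the actual Inductor planner: it stacks each comm buffer at the next free offset of its key's pool, with no in-place reuse, matching the simplification described above (offset order may differ from what the real planner emits).

```python
from collections import defaultdict


def plan_comm_buffers(buffers):
    """Assign each comm buffer an offset within a per-key pool.

    `buffers` is a list of (name, comm_buffer_type, group_name, size_bytes)
    tuples. Returns ({(type, group): pool_size_bytes}, {name: (key, offset)}).
    Buffers with different keys never share a pool, mirroring the isolated
    planning of each (comm_buffer_type, group_name) pair.
    """
    pool_sizes = defaultdict(int)
    placements = {}
    for name, buf_type, group, size in buffers:
        key = (buf_type, group)
        placements[name] = (key, pool_sizes[key])  # next free offset
        pool_sizes[key] += size                    # grow this key's pool
    return dict(pool_sizes), placements
```

With three 64-byte buffers in one group, this yields a single 192-byte pool, matching the pool size in the generated code below.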
### Example Graph
```python
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (4, 4), (4, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        symm_mem_pool_0 = empty_strided_p2p((192, ), (1, ), torch.uint8, torch.device("cuda:0"), group_name="0", alloc_id=13423167936938864713)
        buf0 = alloc_from_pool(symm_mem_pool_0, 128, torch.float32, (4, 4), (4, 1))
        buf3 = alloc_from_pool(symm_mem_pool_0, 64, torch.float32, (4, 4), (4, 1))
        buf6 = alloc_from_pool(symm_mem_pool_0, 0, torch.float32, (4, 4), (4, 1))
        # Topologically Sorted Source Nodes: [a, b, c], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_0.run(arg0_1, buf0, buf3, buf6, 16, grid=grid(16), stream=stream0)
        del arg0_1
        # Topologically Sorted Source Nodes: [a, a_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf1 = torch.ops.symm_mem.one_shot_all_reduce.default(buf0, 'sum', '0')
        buf2 = buf1
        del buf1
        # Topologically Sorted Source Nodes: [b, b_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf4 = torch.ops.symm_mem.one_shot_all_reduce.default(buf3, 'sum', '0')
        buf5 = buf4
        del buf4
        # Topologically Sorted Source Nodes: [c, c_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf7 = torch.ops.symm_mem.one_shot_all_reduce.default(buf6, 'sum', '0')
        buf8 = buf7
        del buf7
        buf9 = alloc_from_pool(symm_mem_pool_0, 0, torch.float32, (4, 4), (4, 1))
        # Topologically Sorted Source Nodes: [add_3, d], Original ATen: [aten.add]
        triton_poi_fused_add_1.run(buf9, buf5, buf8, 16, grid=grid(16), stream=stream0)
        del buf2
        del buf5
        del buf8
        # Topologically Sorted Source Nodes: [add_3, d, d_1], Original ATen: [aten.add, _c10d_functional.all_reduce]
        buf10 = torch.ops.symm_mem.one_shot_all_reduce.default(buf9, 'sum', '0')
        del symm_mem_pool_0, buf0, buf3, buf6, buf9
        buf11 = buf10
        del buf10
    return (buf11, )
```
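The pool size and offsets in the generated code line up arithmetically: each `(4, 4)` float32 buffer occupies 64 bytes, and the three simultaneously-live buffers at offsets 128, 64, and 0 exactly fill the 192-byte pool (`buf9` later reuses offset 0 once `buf6` is dead). A quick sanity check of that arithmetic:

```python
import math

ELEM_BYTES = 4  # torch.float32 element size
shape = (4, 4)
buf_bytes = math.prod(shape) * ELEM_BYTES  # bytes per (4, 4) fp32 buffer

# offsets taken from the alloc_from_pool calls in the generated code above
offsets = {"buf0": 128, "buf3": 64, "buf6": 0}
pool_size = max(off + buf_bytes for off in offsets.values())
```

Here `buf_bytes` is 64 and `pool_size` is 192, matching the `(192,)` uint8 pool created by `empty_strided_p2p`.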
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov
Quoted from the diff:

```
# Comm buffer planning is a collaborative process between Inductor and the
# allocator. Inductor is responsible for planning comm buffers within a graph.
# For each (comm_buffer_type, group_name) pair, Inductor uses a single,
```
Can you define the (comm_buffer_type, group_name) pair (or explain what the different types are for)? I assume group_name refers to a process_group? (Also, why do different PGs need to have different allocations? Because we have to tie the registration to a particular ncclcomm?)
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as `Stale`.
See #162859. This PR adds initial support for symmetric buffers (comm buffers) in `torch.compile` by realizing comm buffers during Inductor lowering and enabling conservative reuse using the existing `memory_plan_reuse` infrastructure.

* **Comm buffer realization**: Each `torch.ops.symm_mem` operation is lowered to allocate a comm buffer via `empty_strided_p2p`.
* **Layout support**: Relaxes layout restrictions so both `FixedLayout` and `FlexibleLayout` buffers can be realized as comm buffers.
* **Comm buffer reuse**: Comm buffers are reused only when their lifetimes do not overlap and when they share an identical reuse key `(device, dtype, size, comm_buffer_type, group_name)`. To prevent mixing communication buffers with regular CUDA buffers, the memory planner maintains a dedicated comm-buffer reuse pool and routes allocations via a `comm_buffer` flag on existing planning lines, eliminating the need for separate comm-buffer-specific line classes.

More general memory planning (like what's proposed in #138519) can be a follow-up.

Pull Request resolved: #171909
Approved by: https://github.com/kwen2501, https://github.com/eellison
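The reuse rule above — a dead comm buffer may only back a new allocation with an identical `(device, dtype, size, comm_buffer_type, group_name)` key — can be sketched as a keyed free list. The class and method names below are illustrative only, not the actual planner API:

```python
def reuse_key(device, dtype, size, comm_buffer_type, group_name):
    """Two comm buffers are interchangeable only if all five fields match."""
    return (device, dtype, size, comm_buffer_type, group_name)


class CommBufferFreeList:
    """Free list keyed by reuse key: a dead comm buffer is only handed back
    out for an allocation with an identical key, so comm buffers never mix
    with regular CUDA buffers or with buffers from other groups/types."""

    def __init__(self):
        self.free = {}

    def release(self, name, key):
        # Buffer `name` is dead; make it available for identical-key reuse.
        self.free.setdefault(key, []).append(name)

    def acquire(self, key):
        bucket = self.free.get(key)
        if bucket:
            return bucket.pop()  # reuse a dead buffer with the same key
        return None  # caller must allocate a fresh comm buffer
```

For example, a freed 64-byte `symm_mem` buffer on `cuda:0` in group `"0"` is reusable for a later allocation with the same key, but not for one with a different size, group, or device.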