Skip to content

Commit 3be70dc

Browse files
v0i0pytorchmergebot
authored andcommitted
[inductor] dont reuse buffers if it affects peak (pytorch#145883) (pytorch#159530)
Pull Request resolved: pytorch#159530 Approved by: https://github.com/eellison
1 parent 47a1db8 commit 3be70dc

5 files changed

Lines changed: 615 additions & 2 deletions

File tree

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Owner(s): ["module: inductor"]
2+
3+
import pytest
4+
from hypothesis import given, strategies as st
5+
6+
from torch._inductor.codegen.segmented_tree import SegmentedTree
7+
from torch.testing._internal.common_utils import run_tests
8+
9+
10+
# Helper functions for operations
11+
def max_op(a, b):
12+
return max(a, b)
13+
14+
15+
def add_op(a, b):
16+
return a + b
17+
18+
19+
# Naive implementations for reference
20+
def naive_range_max(arr, start, end):
21+
return max(arr[start : end + 1])
22+
23+
24+
def naive_range_update(arr, start, end, value):
25+
for i in range(start, end + 1):
26+
arr[i] += value
27+
28+
29+
# Strategies for hypothesis testing
30+
positive_integers = st.lists(
31+
st.integers(min_value=1, max_value=100), min_size=1, max_size=50
32+
)
33+
34+
35+
def valid_range_indices(array_length):
36+
return st.tuples(
37+
st.integers(min_value=0, max_value=array_length - 1),
38+
st.integers(min_value=0, max_value=array_length - 1),
39+
).map(lambda x: (min(x), max(x)))
40+
41+
42+
update_values = st.integers(min_value=1, max_value=50)
43+
44+
45+
# Basic construction and initialization tests
46+
def test_basic_construction():
47+
values = [1, 3, 5, 7, 9]
48+
tree = SegmentedTree(values, add_op, max_op, 0)
49+
assert tree.summarize_range(0, 4) == 9
50+
51+
52+
def test_empty_array():
53+
with pytest.raises(ValueError):
54+
SegmentedTree([], add_op, max_op, 0)
55+
56+
57+
# Property-based tests
58+
@given(values=positive_integers)
59+
def test_max_query_matches_naive(values):
60+
tree = SegmentedTree(values, add_op, max_op, 0)
61+
62+
for start in range(len(values)):
63+
for end in range(start, len(values)):
64+
expected = naive_range_max(values, start, end)
65+
actual = tree.summarize_range(start, end)
66+
assert actual == expected, (
67+
f"Range [{start}:{end}] expected {expected}, got {actual}"
68+
)
69+
70+
71+
@given(values=positive_integers, range_indices=st.data(), update_value=update_values)
72+
def test_range_update(values, range_indices, update_value):
73+
# Create a copy for naive implementation
74+
naive_values = values.copy()
75+
76+
# Create segment tree
77+
tree = SegmentedTree(values, add_op, max_op, 0)
78+
79+
# Get valid range indices
80+
start, end = range_indices.draw(valid_range_indices(len(values)))
81+
82+
# Apply updates
83+
tree.update_range(start, end, update_value)
84+
naive_range_update(naive_values, start, end, update_value)
85+
86+
# Verify all possible ranges
87+
for i in range(len(values)):
88+
for j in range(i, len(values)):
89+
expected = naive_range_max(naive_values, i, j)
90+
actual = tree.summarize_range(i, j)
91+
assert actual == expected, (
92+
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
93+
)
94+
95+
96+
@given(values=positive_integers, range_data=st.data())
97+
def test_multiple_operations(values, range_data):
98+
# Create a copy for naive implementation
99+
naive_values = values.copy()
100+
tree = SegmentedTree(values, add_op, max_op, 0)
101+
102+
# Perform multiple operations
103+
num_operations = 5
104+
for _ in range(num_operations):
105+
# Randomly choose between query and update
106+
operation_type = range_data.draw(st.sampled_from(["query", "update"]))
107+
start, end = range_data.draw(valid_range_indices(len(values)))
108+
109+
if operation_type == "query":
110+
expected = naive_range_max(naive_values, start, end)
111+
actual = tree.summarize_range(start, end)
112+
assert actual == expected, (
113+
f"Range query [{start}:{end}] expected {expected}, got {actual}"
114+
)
115+
else: # update
116+
update_value = range_data.draw(update_values)
117+
tree.update_range(start, end, update_value)
118+
naive_range_update(naive_values, start, end, update_value)
119+
120+
121+
def test_single_element_ranges():
122+
values = [1, 3, 5, 7, 9]
123+
tree = SegmentedTree(values, add_op, max_op, 0)
124+
125+
for i in range(len(values)):
126+
assert tree.summarize_range(i, i) == values[i], (
127+
f"Single element range at index {i} failed"
128+
)
129+
130+
131+
def test_full_array_range():
132+
values = [1, 3, 5, 7, 9]
133+
tree = SegmentedTree(values, add_op, max_op, 0)
134+
135+
# Test querying the entire array
136+
assert tree.summarize_range(0, len(values) - 1) == max(values)
137+
138+
# Update the entire array and test again
139+
update_value = 10
140+
tree.update_range(0, len(values) - 1, update_value)
141+
expected = max([v + update_value for v in values])
142+
assert tree.summarize_range(0, len(values) - 1) == expected
143+
144+
145+
def test_boundary_conditions():
146+
values = [1, 3, 5, 7, 9]
147+
tree = SegmentedTree(values, add_op, max_op, 0)
148+
149+
# Test first element
150+
assert tree.summarize_range(0, 0) == values[0]
151+
152+
# Test last element
153+
assert tree.summarize_range(len(values) - 1, len(values) - 1) == values[-1]
154+
155+
# Test first two elements
156+
assert tree.summarize_range(0, 1) == max(values[0:2])
157+
158+
# Test last two elements
159+
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(values[-2:])
160+
161+
162+
def test_invalid_ranges():
163+
values = [1, 3, 5, 7, 9]
164+
tree = SegmentedTree(values, add_op, max_op, 0)
165+
166+
# Test start > end
167+
with pytest.raises(ValueError):
168+
tree.summarize_range(3, 2)
169+
170+
with pytest.raises(ValueError):
171+
tree.update_range(4, 2, 10)
172+
173+
174+
def test_out_of_bounds():
175+
values = [1, 3, 5, 7, 9]
176+
tree = SegmentedTree(values, add_op, max_op, 0)
177+
178+
# Test negative indices
179+
with pytest.raises(ValueError):
180+
tree.summarize_range(-1, 3)
181+
182+
with pytest.raises(ValueError):
183+
tree.summarize_range(0, -1)
184+
185+
# Test indices >= n
186+
with pytest.raises(ValueError):
187+
tree.summarize_range(0, len(values))
188+
189+
with pytest.raises(ValueError):
190+
tree.summarize_range(len(values), len(values) + 1)
191+
192+
# Test update with out of bounds indices
193+
with pytest.raises(ValueError):
194+
tree.update_range(-1, 3, 10)
195+
196+
with pytest.raises(ValueError):
197+
tree.update_range(0, len(values), 10)
198+
199+
200+
def test_overlapping_updates():
201+
values = [1, 3, 5, 7, 9]
202+
naive_values = values.copy()
203+
tree = SegmentedTree(values, add_op, max_op, 0)
204+
205+
# Apply overlapping updates
206+
tree.update_range(0, 2, 5) # Update [0, 1, 2]
207+
naive_range_update(naive_values, 0, 2, 5)
208+
209+
tree.update_range(1, 3, 3) # Update [1, 2, 3]
210+
naive_range_update(naive_values, 1, 3, 3)
211+
212+
# Verify all possible ranges
213+
for i in range(len(values)):
214+
for j in range(i, len(values)):
215+
expected = naive_range_max(naive_values, i, j)
216+
actual = tree.summarize_range(i, j)
217+
assert actual == expected, (
218+
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
219+
)
220+
221+
222+
def test_sequential_updates_and_queries():
223+
values = [2, 4, 6, 8, 10, 12, 14]
224+
naive_values = values.copy()
225+
tree = SegmentedTree(values, add_op, max_op, 0)
226+
227+
# Sequence of operations
228+
operations = [
229+
("update", 1, 3, 5), # Update range [1, 2, 3] with +5
230+
("query", 0, 4), # Query range [0, 1, 2, 3, 4]
231+
("update", 2, 5, 3), # Update range [2, 3, 4, 5] with +3
232+
("query", 1, 3), # Query range [1, 2, 3]
233+
("update", 0, 6, 2), # Update entire array with +2
234+
("query", 0, 6), # Query entire array
235+
("query", 3, 5), # Query range [3, 4, 5]
236+
]
237+
238+
for op in operations:
239+
if op[0] == "update":
240+
_, start, end, value = op
241+
tree.update_range(start, end, value)
242+
naive_range_update(naive_values, start, end, value)
243+
244+
# Verify tree state after update
245+
for i in range(len(values)):
246+
for j in range(i, len(values)):
247+
expected = naive_range_max(naive_values, i, j)
248+
actual = tree.summarize_range(i, j)
249+
assert actual == expected, (
250+
f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
251+
)
252+
else: # query
253+
_, start, end = op
254+
expected = naive_range_max(naive_values, start, end)
255+
assert tree.summarize_range(start, end) == expected, (
256+
f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
257+
)
258+
259+
260+
if __name__ == "__main__":
261+
run_tests()

test/inductor/test_torchinductor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13754,6 +13754,45 @@ def f(input, repeats):
1375413754
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
1375513755
self.assertEqual(has_lowered, can_lower)
1375613756

13757+
@staticmethod
13758+
def _is_triggering_buffer_reuse(fn, *inputs):
13759+
with config.patch(allow_buffer_reuse=True):
13760+
_, (code_allowed,) = run_and_get_code(fn, *inputs)
13761+
with config.patch(allow_buffer_reuse=False):
13762+
_, (code_disallowed,) = run_and_get_code(fn, *inputs)
13763+
code_allowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_allowed)
13764+
code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed)
13765+
return code_allowed != code_disallowed
13766+
13767+
def test_allow_reuse_disable_if_exceed_peak(self):
13768+
@torch.compile
13769+
def fn(inp): # 1*N^2
13770+
a = inp.mean(-1) # 1*N^2 + N
13771+
b = (inp - a) ** 2 # 2*N^2 + N
13772+
c = b @ b # 3*N^2 (!!) since this is the peak, can not reuse across
13773+
d = c.mean(-1) # 2*N^2 + N
13774+
return d # 1*N^2 + N
13775+
13776+
inp = torch.randn(100, 100, device=self.device)
13777+
self.assertFalse(CommonTemplate._is_triggering_buffer_reuse(fn, inp))
13778+
13779+
def test_allow_reuse_active_if_under_peak(self):
13780+
def g(inp):
13781+
return (inp - torch.logsumexp(inp, -1)) ** 2
13782+
13783+
@torch.compile
13784+
def fn(m, inp):
13785+
inp = m @ g(inp)
13786+
inp = m @ g(inp)
13787+
inp = m @ g(inp)
13788+
inp = m @ g(inp)
13789+
inp = m @ g(inp)
13790+
return inp
13791+
13792+
m = torch.randn(100, 100, device=self.device)
13793+
inp = torch.randn(100, 100, device=self.device)
13794+
self.assertTrue(CommonTemplate._is_triggering_buffer_reuse(fn, m, inp))
13795+
1375713796
# end of class CommonTemplate - add new tests here
1375813797

1375913798

0 commit comments

Comments
 (0)