Skip to content

Commit daa588c

Browse files
committed
Update on "[inductor] dont reuse buffers if it affects peak (#145883)"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben [ghstack-poisoned]
2 parents eca991b + 77e5993 commit daa588c

1 file changed

Lines changed: 7 additions & 14 deletions

File tree

test/inductor/test_segmented_tree.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Owner(s): ["module: inductor"]
22

3-
from torch._inductor.test_case import TestCase, run_tests
43
from hypothesis import given, strategies as st
54

65
from torch._inductor.codegen.segmented_tree import SegmentedTree
6+
from torch._inductor.test_case import run_tests, TestCase
77

88

99
# Helper functions for operations
@@ -48,12 +48,10 @@ def test_basic_construction(self):
4848
tree = SegmentedTree(values, add_op, max_op, 0)
4949
assert tree.summarize_range(0, 4) == 9
5050

51-
5251
def test_empty_array(self):
5352
with self.assertRaises(ValueError):
5453
SegmentedTree([], add_op, max_op, 0)
5554

56-
5755
# Property-based tests
5856
@given(values=positive_integers)
5957
def test_max_query_matches_naive(self, values):
@@ -67,8 +65,9 @@ def test_max_query_matches_naive(self, values):
6765
f"Range [{start}:{end}] expected {expected}, got {actual}"
6866
)
6967

70-
71-
@given(values=positive_integers, range_indices=st.data(), update_value=update_values)
68+
@given(
69+
values=positive_integers, range_indices=st.data(), update_value=update_values
70+
)
7271
def test_range_update(self, values, range_indices, update_value):
7372
# Create a copy for naive implementation
7473
naive_values = values.copy()
@@ -92,7 +91,6 @@ def test_range_update(self, values, range_indices, update_value):
9291
f"After update, range [{i}:{j}] expected {expected}, got {actual}"
9392
)
9493

95-
9694
@given(values=positive_integers, range_data=st.data())
9795
def test_multiple_operations(self, values, range_data):
9896
# Create a copy for naive implementation
@@ -117,7 +115,6 @@ def test_multiple_operations(self, values, range_data):
117115
tree.update_range(start, end, update_value)
118116
naive_range_update(naive_values, start, end, update_value)
119117

120-
121118
def test_single_element_ranges(self):
122119
values = [1, 3, 5, 7, 9]
123120
tree = SegmentedTree(values, add_op, max_op, 0)
@@ -127,7 +124,6 @@ def test_single_element_ranges(self):
127124
f"Single element range at index {i} failed"
128125
)
129126

130-
131127
def test_full_array_range(self):
132128
values = [1, 3, 5, 7, 9]
133129
tree = SegmentedTree(values, add_op, max_op, 0)
@@ -141,7 +137,6 @@ def test_full_array_range(self):
141137
expected = max([v + update_value for v in values])
142138
assert tree.summarize_range(0, len(values) - 1) == expected
143139

144-
145140
def test_boundary_conditions(self):
146141
values = [1, 3, 5, 7, 9]
147142
tree = SegmentedTree(values, add_op, max_op, 0)
@@ -156,8 +151,9 @@ def test_boundary_conditions(self):
156151
assert tree.summarize_range(0, 1) == max(values[0:2])
157152

158153
# Test last two elements
159-
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(values[-2:])
160-
154+
assert tree.summarize_range(len(values) - 2, len(values) - 1) == max(
155+
values[-2:]
156+
)
161157

162158
def test_invalid_ranges(self):
163159
values = [1, 3, 5, 7, 9]
@@ -170,7 +166,6 @@ def test_invalid_ranges(self):
170166
with self.assertRaises(ValueError):
171167
tree.update_range(4, 2, 10)
172168

173-
174169
def test_out_of_bounds(self):
175170
values = [1, 3, 5, 7, 9]
176171
tree = SegmentedTree(values, add_op, max_op, 0)
@@ -196,7 +191,6 @@ def test_out_of_bounds(self):
196191
with self.assertRaises(ValueError):
197192
tree.update_range(0, len(values), 10)
198193

199-
200194
def test_overlapping_updates(self):
201195
values = [1, 3, 5, 7, 9]
202196
naive_values = values.copy()
@@ -218,7 +212,6 @@ def test_overlapping_updates(self):
218212
f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
219213
)
220214

221-
222215
def test_sequential_updates_and_queries(self):
223216
values = [2, 4, 6, 8, 10, 12, 14]
224217
naive_values = values.copy()

0 commit comments

Comments
 (0)