Skip to content

Commit 4e88547

Browse files
Revert "Introduce constrain_range; remove old expr_subs (#95063)"
This reverts commit 3711f7c. Reverted #95063 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, more details can be found: https://fburl.com/phabricator/fq5b6k8a
1 parent 1ab112c commit 4e88547

3 files changed

Lines changed: 36 additions & 80 deletions

File tree

test/test_proxy_tensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212

1313
from torch._decomp import decomposition_table
1414
from torch.fx.experimental.symbolic_shapes import (
15-
sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
16-
constrain_range
15+
sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets
1716
)
1817
from torch.testing._internal.common_device_type import ops
1918
from torch._C import _disabled_torch_function_impl
@@ -900,7 +899,9 @@ def forward(self, a_1):
900899
def test_item_to_constructor(self):
901900
def f(a):
902901
r = a.item()
903-
constrain_range(r, min=0)
902+
r.node.shape_env.expr_subs[r.node.expr].append(((r >= 0).node.expr, True))
903+
# TODO: infer this constraint from r >= 0
904+
r.node.shape_env.expr_subs[r.node.expr].append(((r == -1).node.expr, False))
904905
return torch.empty(r)
905906

906907
r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
@@ -1065,7 +1066,7 @@ def f(a, b):
10651066
from torch._dynamo.source import LocalSource
10661067
self.assertExpectedInline(
10671068
str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])),
1068-
"""['a.size()[0] == 2*b.size()[0]', 'a.stride()[0] == 1', 'a.storage_offset() == 0', 'b.stride()[0] == 1', 'b.storage_offset() == 0', '2 <= b.size()[0]']""" # noqa: B950
1069+
"""['a.size()[0] == 2*b.size()[0]', 'a.stride()[0] == 1', 'a.storage_offset() == 0', 'b.stride()[0] == 1', 'b.storage_offset() == 0', 'b.size()[0] != 0 and b.size()[0] != 1']""" # noqa: B950
10691070
)
10701071

10711072
def test_sym_storage_offset(self):

torch/fx/experimental/symbolic_shapes.py

Lines changed: 28 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from typing import Set, Dict, List, Type, Optional, cast, Union
2+
from typing import Set, Dict, List, Type, Optional, cast, Union, Tuple
33
import sys
44
import builtins
55
import itertools
@@ -17,8 +17,6 @@
1717
# NB: The sym_* functions are used via getattr() and must be imported here.
1818
from torch import SymInt, SymFloat, SymBool, sym_not, sym_float, sym_max, sym_min # noqa: F401
1919
from torch._guards import ShapeGuard, Source
20-
from torch.utils._sympy.value_ranges import ValueRanges, ValueRangeAnalysis
21-
from torch.utils._sympy.interp import sympy_interp
2220

2321
SymTypes = (SymInt, SymFloat, SymBool)
2422

@@ -118,26 +116,6 @@ def guard_scalar(a):
118116
else:
119117
raise AssertionError(f"unrecognized scalar {a}")
120118

121-
# inclusive both ways
122-
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
123-
if min is None:
124-
min = -sympy.oo
125-
if max is None:
126-
max = sympy.oo
127-
if not isinstance(a, SymInt):
128-
assert min <= a <= max
129-
return
130-
if isinstance(a.node.expr, sympy.Integer):
131-
assert min <= int(a.node.expr) <= max
132-
return
133-
# TODO: Turn this into a runtime assert too
134-
assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
135-
r = a.node.shape_env.var_to_range[a.node.expr]
136-
a.node.shape_env.var_to_range[a.node.expr] = ValueRanges(
137-
builtins.max(r.lower, min), builtins.min(r.upper, max)
138-
)
139-
140-
141119
def guard_bool(a):
142120
if isinstance(a, SymBool):
143121
return a.node.guard_bool("", 0) # NB: uses Python backtrace
@@ -1094,11 +1072,6 @@ def __init__(self, allow_scalar_outputs=True, strict_mark_dyn=False, assume_stat
10941072
# Maps symbolic ints to their original concrete values
10951073
# Currently populated from tensors
10961074
self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {}
1097-
# Maps symbolic ints to their min/max range. These ranges
1098-
# are conservative: the int MUST fall in the range, but the
1099-
# range may contain ints which may not actually appear in
1100-
# practice
1101-
self.var_to_range: Dict["sympy.Symbol", ValueRanges] = {}
11021075
# Maps from sympy ints to expressions representing them
11031076
# Populated from equality guards (i.e. a.shape[0] == b.shape[0])
11041077
self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} #
@@ -1109,6 +1082,18 @@ def __init__(self, allow_scalar_outputs=True, strict_mark_dyn=False, assume_stat
11091082
self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}
11101083
self.unbacked_symfloat_counter = itertools.count()
11111084
self.unbacked_symint_counter = itertools.count()
1085+
# A bunch of facts involving unbacked symints that we can
1086+
# attempt replacements with. This is very dumb and should
1087+
# be replaced with a proper entailment mechanism.
1088+
#
1089+
# The dictionary is indexed in the following way. Suppose you have
1090+
# a replacement s0 + s1 to e2. We arbitrarily pick a symbol in
1091+
# the source expression and place this substitution in the list of
1092+
# that key; e.g., {s0: (s0 + s1, e2)}. We will only attempt this
1093+
# substitution if s0 is present in the guard we're attempting to
1094+
# evaluate. The choice of key is arbitrary, since we will check
1095+
# for both s0 and s1 substitutions if s0 + s1 is in the key.
1096+
self.expr_subs: Dict["sympy.Symbol", List[Tuple["sympy.Expr", "sympy.Expr"]]] = collections.defaultdict(list)
11121097
self.strict_mark_dyn = strict_mark_dyn
11131098
self.assume_static_by_default = assume_static_by_default
11141099

@@ -1205,13 +1190,11 @@ def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
12051190
def create_unbacked_symfloat(self):
12061191
symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
12071192
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
1208-
self.var_to_range[symbol] = ValueRanges.unknown()
12091193
return SymFloat(SymNode(symbol, self, float, None))
12101194

12111195
def create_unbacked_symint(self):
12121196
symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
12131197
symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
1214-
self.var_to_range[symbol] = ValueRanges.unknown()
12151198
return SymInt(SymNode(symbol, self, int, None))
12161199

12171200
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
@@ -1231,13 +1214,8 @@ def create_symbol(self, val: int, source: Source, dyn=False) -> "sympy.Expr":
12311214
self.var_to_val[sympy_expr] = sympy.Integer(val)
12321215

12331216
if not dyn:
1234-
# Non explicitly marked dynamic dims register to val_to_var to get duck shaped
1217+
# Only non dynamic goes here
12351218
self.val_to_var[val] = sympy_expr
1236-
# We also infer that they must not be 0/1
1237-
self.var_to_range[sympy_expr] = ValueRanges(2, sympy.oo)
1238-
else:
1239-
# Avoid up front 0/1 specializing dynamic dims
1240-
self.var_to_range[sympy_expr] = ValueRanges(0, sympy.oo)
12411219

12421220
if not dyn:
12431221
# This implements duck-shaping: input sizes that match are assigned
@@ -1444,23 +1422,13 @@ def _verify(expr, potential_expr):
14441422
log.warning(f"Failing guard allocated at: \n{tb}")
14451423
raise
14461424

1447-
# 3. Every symbol must be within its value range (this handles 0/1
1448-
# specialization too). NB: because we never update value ranges
1449-
# except in case of explicit user annotation, these are not included
1450-
# in simplified. However, when we start updating value ranges
1451-
# these should probably get reported in tests too
1425+
# 3. Every symbol must not be equal to 0/1
14521426
if not _simplified:
1453-
for symbol, sources in symbol_to_source.items():
1427+
for sources in symbol_to_source.values():
14541428
assert sources
1455-
r = self.var_to_range[symbol]
1456-
bounds = []
1457-
if r.lower != -sympy.oo:
1458-
bounds.append(str(r.lower))
1459-
bounds.append(source_ref(sources[0]))
1460-
if r.upper != sympy.oo:
1461-
bounds.append(str(r.upper))
1462-
if len(bounds) > 1:
1463-
exprs.append(" <= ".join(bounds))
1429+
# We must assert that each symbol is not zero or one, as we make
1430+
# negative inferences on shape variables
1431+
exprs.append(f"{source_ref(sources[0])} != 0 and {source_ref(sources[0])} != 1")
14641432

14651433
return exprs
14661434

@@ -1559,20 +1527,11 @@ def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]":
15591527
if len(list(new_expr.free_symbols)) == 0:
15601528
return new_expr
15611529

1562-
# Check if the range can solve it statically
1563-
range_env = {
1564-
s: self.var_to_range[s]
1565-
for s in expr.free_symbols
1566-
if s not in self.var_to_val
1567-
}
1568-
range_env.update({
1569-
new_shape_env[s] - 1: ValueRangeAnalysis.sub(self.var_to_range[s], 1)
1570-
for s in expr.free_symbols
1571-
if s in self.var_to_val
1572-
})
1573-
out = sympy_interp(ValueRangeAnalysis, range_env, new_expr)
1574-
if out.is_singleton():
1575-
return out.lower
1530+
# Attempt expr_subs on the original expression
1531+
for s in new_expr.free_symbols:
1532+
new_expr = new_expr.subs(self.expr_subs[s])
1533+
if len(list(new_expr.free_symbols)) == 0:
1534+
return new_expr
15761535

15771536
return None
15781537

@@ -1638,13 +1597,10 @@ def size_hint(self, expr: "sympy.Expr"):
16381597
"""
16391598
result_expr = safe_expand(expr).xreplace(self.var_to_val)
16401599
if len(result_expr.free_symbols) != 0:
1641-
range_env = {
1642-
s: self.var_to_range[s]
1643-
for s in result_expr.free_symbols
1644-
}
1645-
out = sympy_interp(ValueRangeAnalysis, range_env, result_expr)
1646-
if out.is_singleton():
1647-
return out.lower
1600+
for s in result_expr.free_symbols:
1601+
result_expr = result_expr.subs(self.expr_subs[s])
1602+
if len(list(result_expr.free_symbols)) == 0:
1603+
return result_expr
16481604
raise self._make_data_dependent_error(result_expr)
16491605
return result_expr
16501606

torch/utils/_sympy/interp.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def sympy_interp(
6666
# sometimes?
6767
if isinstance(expr, sympy.Integer):
6868
return analysis.constant(int(expr), torch.int64)
69-
elif isinstance(expr, sympy.Number):
69+
elif isinstance(expr, sympy.Float):
7070
return analysis.constant(float(expr), torch.double)
7171
elif isinstance(expr, BooleanAtom):
7272
return analysis.constant(bool(expr), torch.bool)
@@ -81,9 +81,8 @@ def sympy_interp(
8181

8282
# Recursive case
8383
args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type]
84-
handler_name = handlers()[expr.func]
85-
handler = getattr(analysis, handler_name)
86-
if handler_name in ASSOCIATIVE_OPS:
84+
handler = getattr(analysis, handlers()[expr.func])
85+
if handler in ASSOCIATIVE_OPS:
8786
assert len(args) > 1
8887
acc = handler(args[0], args[1])
8988
for i in range(2, len(args)):

0 commit comments

Comments
 (0)