11import torch
2- from typing import Set , Dict , List , Type , Optional , cast , Union
2+ from typing import Set , Dict , List , Type , Optional , cast , Union , Tuple
33import sys
44import builtins
55import itertools
1717# NB: The sym_* functions are used via getattr() and must be imported here.
1818from torch import SymInt , SymFloat , SymBool , sym_not , sym_float , sym_max , sym_min # noqa: F401
1919from torch ._guards import ShapeGuard , Source
20- from torch .utils ._sympy .value_ranges import ValueRanges , ValueRangeAnalysis
21- from torch .utils ._sympy .interp import sympy_interp
2220
2321SymTypes = (SymInt , SymFloat , SymBool )
2422
@@ -118,26 +116,6 @@ def guard_scalar(a):
118116 else :
119117 raise AssertionError (f"unrecognized scalar { a } " )
120118
121- # inclusive both ways
122- def constrain_range (a , * , min : Optional [int ], max : Optional [int ] = None ):
123- if min is None :
124- min = - sympy .oo
125- if max is None :
126- max = sympy .oo
127- if not isinstance (a , SymInt ):
128- assert min <= a <= max
129- return
130- if isinstance (a .node .expr , sympy .Integer ):
131- assert min <= int (a .node .expr ) <= max
132- return
133- # TODO: Turn this into a runtime assert too
134- assert isinstance (a .node .expr , sympy .Symbol ), "constraining non-Symbols NYI"
135- r = a .node .shape_env .var_to_range [a .node .expr ]
136- a .node .shape_env .var_to_range [a .node .expr ] = ValueRanges (
137- builtins .max (r .lower , min ), builtins .min (r .upper , max )
138- )
139-
140-
141119def guard_bool (a ):
142120 if isinstance (a , SymBool ):
143121 return a .node .guard_bool ("" , 0 ) # NB: uses Python backtrace
@@ -1094,11 +1072,6 @@ def __init__(self, allow_scalar_outputs=True, strict_mark_dyn=False, assume_stat
10941072 # Maps symbolic ints to their original concrete values
10951073 # Currently populated from tensors
10961074 self .var_to_val : Dict ["sympy.Symbol" , "sympy.Integer" ] = {}
1097- # Maps symbolic ints to their min/max range. These ranges
1098- # are conservative: the int MUST fall in the range, but the
1099- # range may contain ints which may not actually appear in
1100- # practice
1101- self .var_to_range : Dict ["sympy.Symbol" , ValueRanges ] = {}
11021075 # Maps from sympy ints to expressions representing them
11031076 # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
11041077 self .replacements : Dict ["sympy.Symbol" , "sympy.Expr" ] = {} #
@@ -1109,6 +1082,18 @@ def __init__(self, allow_scalar_outputs=True, strict_mark_dyn=False, assume_stat
11091082 self .val_to_var : Dict [int , "sympy.Expr" ] = {0 : sympy .Integer (0 ), 1 : sympy .Integer (1 )}
11101083 self .unbacked_symfloat_counter = itertools .count ()
11111084 self .unbacked_symint_counter = itertools .count ()
1085+ # A bunch of facts involving unbacked symints that we can
1086+ # attempt replacements with. This is very dumb and should
1087+ # be replaced with a proper entailment mechanism.
1088+ #
1089+ # The dictionary is indexed in the following way. Suppose you have
1090+ # a replacement s0 + s1 to e2. We arbitrarily pick a symbol in
1091+ # the source expression and place this substitution in the list of
1092+ # that key; e.g., {s0: (s0 + s1, e2)}. We will only attempt this
1093+ # substitution if s0 is present in the guard we're attempting to
1094+ # evaluate. The choice of key is arbitrary, since we will check
1095+ # for both s0 and s1 substitutions if s0 + s1 is in the key.
1096+ self .expr_subs : Dict ["sympy.Symbol" , List [Tuple ["sympy.Expr" , "sympy.Expr" ]]] = collections .defaultdict (list )
11121097 self .strict_mark_dyn = strict_mark_dyn
11131098 self .assume_static_by_default = assume_static_by_default
11141099
@@ -1205,13 +1190,11 @@ def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
12051190 def create_unbacked_symfloat (self ):
12061191 symbol = Symbol (f"f{ next (self .unbacked_symfloat_counter )} " )
12071192 symbol .stack = '' .join (traceback .format_list (traceback .extract_stack ()[:- 1 ]))
1208- self .var_to_range [symbol ] = ValueRanges .unknown ()
12091193 return SymFloat (SymNode (symbol , self , float , None ))
12101194
12111195 def create_unbacked_symint (self ):
12121196 symbol = Symbol (f"i{ next (self .unbacked_symint_counter )} " , integer = True )
12131197 symbol .stack = '' .join (traceback .format_list (traceback .extract_stack ()[:- 1 ]))
1214- self .var_to_range [symbol ] = ValueRanges .unknown ()
12151198 return SymInt (SymNode (symbol , self , int , None ))
12161199
12171200 # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
@@ -1231,13 +1214,8 @@ def create_symbol(self, val: int, source: Source, dyn=False) -> "sympy.Expr":
12311214 self .var_to_val [sympy_expr ] = sympy .Integer (val )
12321215
12331216 if not dyn :
1234- # Non explicitly marked dynamic dims register to val_to_var to get duck shaped
1217+ # Only non dynamic goes here
12351218 self .val_to_var [val ] = sympy_expr
1236- # We also infer that they must not be 0/1
1237- self .var_to_range [sympy_expr ] = ValueRanges (2 , sympy .oo )
1238- else :
1239- # Avoid up front 0/1 specializing dynamic dims
1240- self .var_to_range [sympy_expr ] = ValueRanges (0 , sympy .oo )
12411219
12421220 if not dyn :
12431221 # This implements duck-shaping: input sizes that match are assigned
@@ -1444,23 +1422,13 @@ def _verify(expr, potential_expr):
14441422 log .warning (f"Failing guard allocated at: \n { tb } " )
14451423 raise
14461424
1447- # 3. Every symbol must be within its value range (this handles 0/1
1448- # specialization too). NB: because we never update value ranges
1449- # except in case of explicit user annotation, these are not included
1450- # in simplified. However, when we start updating value ranges
1451- # these should probably get reported in tests too
1425+ # 3. Every symbol must not be equal to 0/1
14521426 if not _simplified :
1453- for symbol , sources in symbol_to_source .items ():
1427+ for sources in symbol_to_source .values ():
14541428 assert sources
1455- r = self .var_to_range [symbol ]
1456- bounds = []
1457- if r .lower != - sympy .oo :
1458- bounds .append (str (r .lower ))
1459- bounds .append (source_ref (sources [0 ]))
1460- if r .upper != sympy .oo :
1461- bounds .append (str (r .upper ))
1462- if len (bounds ) > 1 :
1463- exprs .append (" <= " .join (bounds ))
1429+ # We must assert that each symbol is not zero or one, as we make
1430+ # negative inferences on shape variables
1431+ exprs .append (f"{ source_ref (sources [0 ])} != 0 and { source_ref (sources [0 ])} != 1" )
14641432
14651433 return exprs
14661434
@@ -1559,20 +1527,11 @@ def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]":
15591527 if len (list (new_expr .free_symbols )) == 0 :
15601528 return new_expr
15611529
1562- # Check if the range can solve it statically
1563- range_env = {
1564- s : self .var_to_range [s ]
1565- for s in expr .free_symbols
1566- if s not in self .var_to_val
1567- }
1568- range_env .update ({
1569- new_shape_env [s ] - 1 : ValueRangeAnalysis .sub (self .var_to_range [s ], 1 )
1570- for s in expr .free_symbols
1571- if s in self .var_to_val
1572- })
1573- out = sympy_interp (ValueRangeAnalysis , range_env , new_expr )
1574- if out .is_singleton ():
1575- return out .lower
1530+ # Attempt expr_subs on the original expression
1531+ for s in new_expr .free_symbols :
1532+ new_expr = new_expr .subs (self .expr_subs [s ])
1533+ if len (list (new_expr .free_symbols )) == 0 :
1534+ return new_expr
15761535
15771536 return None
15781537
@@ -1638,13 +1597,10 @@ def size_hint(self, expr: "sympy.Expr"):
16381597 """
16391598 result_expr = safe_expand (expr ).xreplace (self .var_to_val )
16401599 if len (result_expr .free_symbols ) != 0 :
1641- range_env = {
1642- s : self .var_to_range [s ]
1643- for s in result_expr .free_symbols
1644- }
1645- out = sympy_interp (ValueRangeAnalysis , range_env , result_expr )
1646- if out .is_singleton ():
1647- return out .lower
1600+ for s in result_expr .free_symbols :
1601+ result_expr = result_expr .subs (self .expr_subs [s ])
1602+ if len (list (result_expr .free_symbols )) == 0 :
1603+ return result_expr
16481604 raise self ._make_data_dependent_error (result_expr )
16491605 return result_expr
16501606
0 commit comments