Skip to content

Commit 53ecb81

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
Introduce statically_known_false (#154291)
Pull Request resolved: #154291 Approved by: https://github.com/mengluy0125
1 parent 2dfc0e3 commit 53ecb81

5 files changed

Lines changed: 68 additions & 0 deletions

File tree

docs/source/fx.experimental.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ torch.fx.experimental.symbolic_shapes
4848
constrain_unify
4949
canonicalize_bool_expr
5050
statically_known_true
51+
statically_known_false
5152
has_static_value
5253
lru_cache
5354
check_consistent

test/test_dynamic_shapes.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
is_symbolic,
3535
ShapeEnv,
3636
StatelessSymbolicContext,
37+
statically_known_false,
3738
statically_known_true,
3839
)
3940
from torch.testing._internal.common_dtype import all_types_and
@@ -1214,6 +1215,36 @@ def test_statically_known_true(self):
12141215
# No guards should be generated
12151216
self.assertEqual(len(shape_env.guards), 0)
12161217

1218+
def test_statically_known_false(self):
1219+
shape_env = ShapeEnv()
1220+
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
1221+
1222+
# Statically known true
1223+
self.assertFalse(statically_known_false(True))
1224+
self.assertFalse(statically_known_false(s2 == s2))
1225+
self.assertFalse(statically_known_false(s2 * s3 > s3))
1226+
self.assertFalse(statically_known_false(s3 * s4 > s4))
1227+
self.assertFalse(statically_known_false((s3 + s3) % 2 == 0))
1228+
1229+
# Statically known false
1230+
self.assertTrue(statically_known_false(False))
1231+
self.assertTrue(statically_known_false(s3 * s4 <= s4))
1232+
self.assertTrue(statically_known_false((s3 + s3) % 2 == 1))
1233+
1234+
# True for hints, but not known statically
1235+
self.assertFalse(statically_known_false(s2 + s2 == s4))
1236+
self.assertFalse(statically_known_false(s4 % s2 == 0))
1237+
self.assertFalse(statically_known_false(s2 != s3))
1238+
self.assertFalse(statically_known_false(s3 * s4 > s2))
1239+
1240+
# False for hints, but not known statically
1241+
self.assertFalse(statically_known_false(s2 == s3))
1242+
self.assertFalse(statically_known_false(s2 > s3))
1243+
self.assertFalse(statically_known_false(s3 + s3 == s4))
1244+
1245+
# No guards should be generated
1246+
self.assertEqual(len(shape_env.guards), 0)
1247+
12171248
def test_ephemeral_source_simplification(self):
12181249
from torch._dynamo.source import EphemeralSource
12191250

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@
315315
"torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable,
316316
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
317317
"torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable,
318+
"torch.fx.experimental.symbolic_shapes.statically_known_false": TorchInGraphFunctionVariable,
318319
"torch.fx.experimental.symbolic_shapes.has_static_value": TorchInGraphFunctionVariable,
319320
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
320321
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,

torch/_dynamo/variables/torch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,17 @@ def handle_guard_or_false(self, tx: "InstructionTranslator", expr):
923923
elif isinstance(expr, ConstantVariable):
924924
return expr
925925

926+
@register(torch.fx.experimental.symbolic_shapes.statically_known_false)
927+
def handle_statically_known_false(self, tx: "InstructionTranslator", expr):
928+
if isinstance(expr, SymNodeVariable):
929+
return variables.ConstantVariable.create(
930+
torch.fx.experimental.symbolic_shapes.statically_known_false(
931+
expr.sym_num
932+
)
933+
)
934+
elif isinstance(expr, ConstantVariable):
935+
return expr
936+
926937
@register(torch.fx.experimental.symbolic_shapes.statically_known_true)
927938
def handle_statically_known_true(self, tx: "InstructionTranslator", expr):
928939
if isinstance(expr, SymNodeVariable):

torch/fx/experimental/symbolic_shapes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
160160
"SymIntSymbolicContext",
161161
"TrackedFake",
162162
"statically_known_true",
163+
"statically_known_false",
163164
"guard_size_oblivious",
164165
"check_consistent",
165166
"compute_unbacked_bindings",
@@ -1292,6 +1293,29 @@ def _static_eval_sym_bool(x: SymBool) -> Optional[bool]:
12921293
return None
12931294

12941295

1296+
def statically_known_false(x: BoolLikeType) -> bool:
1297+
"""
1298+
Returns True if x can be simplified to a constant and is False.
1299+
If x cannot be evaluated from static, we return False
1300+
1301+
.. note::
1302+
This function doesn't introduce new guards, so the expression may end
1303+
up evaluating to False at runtime even if this function returns False.
1304+
1305+
Args:
1306+
x (bool, SymBool): The expression to try statically evaluating
1307+
"""
1308+
if not isinstance(x, SymBool):
1309+
assert isinstance(x, bool)
1310+
return not x
1311+
1312+
result = _static_eval_sym_bool(x)
1313+
if result is None:
1314+
return False
1315+
1316+
return not result
1317+
1318+
12951319
def statically_known_true(x: BoolLikeType) -> bool:
12961320
"""
12971321
Returns True if x can be simplified to a constant and is true.

0 commit comments

Comments
 (0)