Skip to content

Commit d869344

Browse files
pianpwk authored and pytorchmergebot committed
[export] serialize sympy.Exprs as ASTs instead of strings (#140084)
Summary: The way we've been de/serializing sympy.Exprs is not roundtrippable in all cases (we serialize by calling `str(expr)` and deserialize by calling `sympy.sympify(expr_str)`). This has led to expressions being mathematically equivalent but structurally different, causing issues in ValueRanges. Example issue: #136797. This commit starts to deprecate the use of `expr_str` and stores expressions in AST format instead. For BC purposes, `expr_str` deserialization is still supported, but we will always serialize to `expr_ast`. We'll kill this once the serialization upgrader design is finalized and implemented. Test Plan: test_export. Differential Revision: D65638757. Pull Request resolved: #140084. Approved by: https://github.com/angelayi
1 parent 7e9e83a commit d869344

6 files changed

Lines changed: 365 additions & 73 deletions

File tree

test/export/test_export.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8133,7 +8133,6 @@ def forward(self, x):
81338133
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
81348134

81358135
@testing.expectedFailureRetraceabilityNonStrict
8136-
@testing.expectedFailureCppSerDes # dynamic shape serialization
81378136
def test_disable_forced_specializations_ok(self):
81388137
# check that we don't force specialization, and defer to runtime asserts
81398138
# with allow_complex_guards_as_runtime_asserts=True to successfully export
@@ -9192,9 +9191,7 @@ def forward(self, input1: torch.Tensor):
91929191
inps = (torch.randn(1, 224, 768, device="cpu"),)
91939192
export(Foo(), inps)
91949193

9195-
@testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization
9196-
@testing.expectedFailureCppSerDes # TODO(pianpwk): PowByNatural valuerange deserialization
9197-
@testing.expectedFailureSerDerNonStrict
9194+
@testing.expectedFailureCppSerDes
91989195
@testing.expectedFailureRetraceabilityNonStrict
91999196
def test_dim_dynamic(self):
92009197
dynamic = Dim.DYNAMIC

torch/_export/serde/schema.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch._export.serde.union import _Union
99

1010
# NOTE: Please update this value if any modifications are made to the schema
11-
SCHEMA_VERSION = (8, 1)
11+
SCHEMA_VERSION = (8, 2)
1212
TREESPEC_VERSION = 1
1313

1414

@@ -62,13 +62,42 @@ class SymExprHint(_Union):
6262
as_bool: bool
6363

6464

65+
# A leaf node in a SymExprNode, containing a bool/float/int/sympy.Symbol.
66+
@dataclass(repr=False)
67+
class SymBase(_Union):
68+
as_bool: bool
69+
as_float: float
70+
as_int: int
71+
as_symbol: str
72+
73+
74+
# Represents an AST node in a sympy.Expr.
75+
# If not a leaf node, "target" is a string representing the operator,
76+
# and "args" is a list of child SymExprNodes.
77+
# If a leaf node, "target" is None, "args" is empty, and "base" is a SymBase
78+
# representing the leaf value.
79+
@dataclass(repr=False)
80+
class SymExprNode:
81+
args: List["SymExprNode"] = field(default_factory=list)
82+
target: Optional[str] = None
83+
base: Optional[SymBase] = None
84+
85+
6586
# This is for storing the symbolic expressions behind symints/symfloats/symbools
66-
# For example, we can get something like
87+
# The deprecated "expr_str" field is easier to explain; we could store
6788
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
68-
# if we also have the hint that s0 and s1 are both 2.
89+
# for an expression where s0 == s1 == 2.
90+
# We're moving away from expr_str for roundtrippability, and now serialize into
91+
# the "expr_ast" field, which is a tree representation of the expression,
92+
# containing a root SymExprNode.
93+
# While we're deprecating this, we'll store an empty string in "expr_str" for now,
94+
# and support deserialization for both "expr_str" and "expr_ast" fields.
95+
# We'll only serialize to "expr_ast".
96+
# TODO(pianpwk): implement upgrader & delete.
6997
@dataclass
7098
class SymExpr:
7199
expr_str: str
100+
expr_ast: Optional[SymExprNode] = None
72101
hint: Optional[SymExprHint] = None
73102

74103

torch/_export/serde/schema.yaml

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# @generated by update_schema.py
2-
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
2+
# checksum<<74ae9f550efb42873fc58f07990cefe4da35d38c673beeacbe1e20808d9a6962>>
33
Argument:
44
kind: union
55
fields:
@@ -346,6 +346,17 @@ SchemaVersion:
346346
type: int
347347
minor:
348348
type: int
349+
SymBase:
350+
kind: union
351+
fields:
352+
as_bool:
353+
type: bool
354+
as_float:
355+
type: float
356+
as_int:
357+
type: int
358+
as_symbol:
359+
type: str
349360
SymBool:
350361
kind: union
351362
fields:
@@ -365,6 +376,9 @@ SymExpr:
365376
fields:
366377
expr_str:
367378
type: str
379+
expr_ast:
380+
type: Optional[SymExprNode]
381+
default: None
368382
hint:
369383
type: Optional[SymExprHint]
370384
default: None
@@ -377,6 +391,18 @@ SymExprHint:
377391
type: float
378392
as_bool:
379393
type: bool
394+
SymExprNode:
395+
kind: struct
396+
fields:
397+
args:
398+
type: List[SymExprNode]
399+
default: '[]'
400+
target:
401+
type: Optional[str]
402+
default: None
403+
base:
404+
type: Optional[SymBase]
405+
default: None
380406
SymInt:
381407
kind: union
382408
fields:
@@ -437,5 +463,5 @@ UserOutputSpec:
437463
type: Argument
438464
SCHEMA_VERSION:
439465
- 8
440-
- 1
466+
- 2
441467
TREESPEC_VERSION: 1

torch/_export/serde/schema_check.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
import typing
77
from enum import IntEnum
8-
from typing import Any, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Dict, ForwardRef, List, Optional, Tuple, Union
99

1010
from torch._export.serde import schema
1111
from torch._export.serde.union import _Union
@@ -71,6 +71,8 @@ def dump_type(t) -> Tuple[str, str]:
7171
)
7272
elif t == ():
7373
return "()", ""
74+
elif isinstance(t, ForwardRef):
75+
return t.__forward_arg__, f"ForwardRef<{t.__forward_arg__}>"
7476
else:
7577
raise AssertionError(f"Type {t} is not supported in export schema.")
7678

0 commit comments

Comments
 (0)