|
8 | 8 | from torch._export.serde.union import _Union |
9 | 9 |
|
10 | 10 | # NOTE: Please update this value if any modifications are made to the schema |
11 | | -SCHEMA_VERSION = (8, 1) |
| 11 | +SCHEMA_VERSION = (8, 2) |
12 | 12 | TREESPEC_VERSION = 1 |
13 | 13 |
|
14 | 14 |
|
@@ -62,13 +62,42 @@ class SymExprHint(_Union): |
62 | 62 | as_bool: bool |
63 | 63 |
|
64 | 64 |
|
| 65 | +# A leaf node in a SymExprNode, containing a bool/float/int/sympy.Symbol. |
| 66 | +@dataclass(repr=False) |
| 67 | +class SymBase(_Union): |
| 68 | + as_bool: bool |
| 69 | + as_float: float |
| 70 | + as_int: int |
| 71 | + as_symbol: str |
| 72 | + |
| 73 | + |
| 74 | +# Represents an AST node in a sympy.Expr. |
| 75 | +# If not a leaf node, "target" is a string representing the operator, |
| 76 | +# and "args" is a list of child SymExprNodes. |
| 77 | +# If a leaf node, "target" is None, "args" is empty, and "base" is a SymBase |
| 78 | +# representing the leaf value. |
| 79 | +@dataclass(repr=False) |
| 80 | +class SymExprNode: |
| 81 | + args: List["SymExprNode"] = field(default_factory=list) |
| 82 | + target: Optional[str] = None |
| 83 | + base: Optional[SymBase] = None |
| 84 | + |
| 85 | + |
65 | 86 | # This is for storing the symbolic expressions behind symints/symfloats/symbools |
66 | | -# For example, we can get something like |
| 87 | +# The deprecated "expr_str" field is easier to explain; we could store |
67 | 88 | # SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4) |
68 | | -# if we also have the hint that s0 and s1 are both 2. |
| 89 | +# for an expression where s0 == s1 == 2. |
| 90 | +# We're moving away from expr_str for roundtrippability, and now deserialize into |
| 91 | +# the "expr_ast" field, which is a tree representation of the expression, |
| 92 | +# containing a root SymExprNode. |
| 93 | +# While we're deprecating this, we'll store an empty string in "expr_str" for now, |
| 94 | +# and support deserialization for both "expr_str" and "expr_ast" fields. |
| 95 | +# We'll only serialize to "expr_ast". |
| 96 | +# TODO(pianpwk): implement upgrader & delete. |
69 | 97 | @dataclass |
70 | 98 | class SymExpr: |
71 | 99 | expr_str: str |
| 100 | + expr_ast: Optional[SymExprNode] = None |
72 | 101 | hint: Optional[SymExprHint] = None |
73 | 102 |
|
74 | 103 |
|
|
0 commit comments