Skip to content

Commit 11c786d

Browse files
tugsbayasgalan authored and pytorchmergebot committed
[BE] Make maybe_aliasing_or_mutating proper tag (#131990)
For better tracking, we need to make maybe aliasing/mutating ops with proper tag. We need to special case native_batch_norm because it is not a CIA but has a wrong schema. I guess native_batch_norm will be removed at some point, so until then we just keep it around. D60347117 Pull Request resolved: #131990 Approved by: https://github.com/bdhirsh
1 parent c513f01 commit 11c786d

7 files changed

Lines changed: 59 additions & 32 deletions

File tree

aten/src/ATen/native/native_functions.yaml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,25 +312,25 @@
312312
- func: _shape_as_tensor(Tensor self) -> Tensor
313313

314314
- func: dropout(Tensor input, float p, bool train) -> Tensor
315-
tags: nondeterministic_seeded
315+
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
316316

317317
- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
318318
tags: nondeterministic_seeded
319319

320320
- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
321-
tags: nondeterministic_seeded
321+
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
322322

323323
- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
324324
tags: nondeterministic_seeded
325325

326326
- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
327-
tags: nondeterministic_seeded
327+
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
328328

329329
- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
330330
tags: nondeterministic_seeded
331331

332332
- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
333-
tags: nondeterministic_seeded
333+
tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
334334

335335
- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
336336
tags: nondeterministic_seeded
@@ -480,7 +480,7 @@
480480

481481
- func: conj_physical(Tensor self) -> Tensor
482482
variants: function, method
483-
tags: pointwise
483+
tags: [pointwise, maybe_aliasing_or_mutating]
484484

485485
- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
486486
dispatch:
@@ -1035,17 +1035,20 @@
10351035

10361036
- func: atleast_1d(Tensor self) -> Tensor
10371037
variants: function
1038+
tags: maybe_aliasing_or_mutating
10381039

10391040
- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]
10401041

10411042
- func: atleast_2d(Tensor self) -> Tensor
10421043
variants: function
1044+
tags: maybe_aliasing_or_mutating
10431045

10441046
- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
10451047
variants: function
10461048

10471049
- func: atleast_3d(Tensor self) -> Tensor
10481050
variants: function
1051+
tags: maybe_aliasing_or_mutating
10491052

10501053
- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
10511054
variants: function
@@ -1079,13 +1082,15 @@
10791082
autogen: bartlett_window.periodic_out
10801083

10811084
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
1085+
tags: maybe_aliasing_or_mutating
10821086

10831087
- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
10841088
dispatch:
10851089
QuantizedCPU: quantized_batch_norm
10861090
autogen: quantized_batch_norm.out
10871091

10881092
- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
1093+
tags: maybe_aliasing_or_mutating
10891094

10901095
- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
10911096

@@ -1468,6 +1473,7 @@
14681473
variants: function, method
14691474
device_check: NoCheck
14701475
device_guard: False
1476+
tags: maybe_aliasing_or_mutating
14711477

14721478
- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
14731479
variants: function, method
@@ -7758,6 +7764,7 @@
77587764

77597765
- func: cartesian_prod(Tensor[] tensors) -> Tensor
77607766
variants: function
7767+
tags: maybe_aliasing_or_mutating
77617768

77627769
- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
77637770
variants: function

aten/src/ATen/native/tags.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,9 @@
7272
Pointwise operators are operators where each element of the output is computed only by accessing
7373
the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
7474
shape of the inputs.
75+
- tag: maybe_aliasing_or_mutating
76+
desc: |
77+
For some ops, we can't statically determine whether the op is functional or not. Note that this is only
78+
relevant to CIA ops that decompose before functionalization/autograd. It is useful to
79+
know this information for export as we would want to decompose these ops as they are unsafe to be
80+
preserved.

torch/_decomp/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"register_decomposition",
3232
"get_decompositions",
3333
"core_aten_decompositions",
34-
"_special_op_to_preserve_cia",
34+
"_should_decompose_because_unsafe_op",
3535
]
3636

3737
_T = TypeVar("_T")
@@ -48,6 +48,24 @@
4848
meta_table = global_decomposition_table["meta"]
4949

5050

51+
def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
52+
"""
53+
Returns True if the op must always decompose in export/compile tracing system
54+
55+
In export, we always decompose certain CIA ops that are tagged with
56+
maybe_aliasing_or_mutating because we statically need to know if the op is
57+
mutating or not. But these CIA ops could have different behaviour in runtime.
58+
59+
native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
60+
with correct schema. But until then, we will force decompose it via this tag.
61+
"""
62+
if not isinstance(op, torch._ops.OpOverload):
63+
return False
64+
if torch.Tag.maybe_aliasing_or_mutating in op.tags:
65+
return True
66+
return op == torch.ops.aten.native_batch_norm.default
67+
68+
5169
def _add_op_to_registry(registry, op, fn):
5270
"""
5371
This is an internal API for adding an op to the decomposition table.

torch/_export/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
4646

4747
from .wrappers import _wrap_submodules
48+
from .utils import _materialize_cpp_cia_ops
4849

4950
log = logging.getLogger(__name__)
5051

@@ -169,9 +170,23 @@ def capture_pre_autograd_graph(
169170
# Do not decompose dropout for exported models, because in eval mode the dropout
170171
# op disappears from the graph, which makes it difficult to switch to train mode.
171172
# See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
173+
174+
# We force create native_batch_norm because the below materialization logic
175+
# only applies to CIA ops.
176+
maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]
177+
178+
_materialize_cpp_cia_ops()
179+
180+
for op in torch.ops.aten:
181+
op_obj = getattr(torch.ops.aten, op)
182+
for overload in op_obj.overloads():
183+
op_overload = getattr(op_obj, overload)
184+
if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
185+
maybe_aliasing_or_mutating_ops.append(op_overload)
186+
172187
decomp_table = {
173188
op: op.decompose
174-
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
189+
for op in maybe_aliasing_or_mutating_ops
175190
if op != torch.ops.aten.dropout.default
176191
}
177192
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():

torch/_export/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,10 +1036,10 @@ def _special_op_to_preserve_cia(*args, **kwargs):
10361036
# 1. The op should be known statically that it is functional
10371037
# 2. If it is maybe aliasing, we decompose because we must know if an op
10381038
# is mutating or aliasing.
1039-
# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor
1040-
# decomp part. (https://github.com/pytorch/pytorch/issues/129431)
10411039
def _check_valid_to_preserve(op_overload: "OperatorBase"):
1042-
if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops:
1040+
from torch._decomp import _should_decompose_because_unsafe_op
1041+
1042+
if _should_decompose_because_unsafe_op(op_overload):
10431043
return False
10441044
if op_overload in FunctionalTensor.metadata_fns:
10451045
return False

torch/_subclasses/functional_tensor.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,26 +91,6 @@ class FunctionalTensor(torch.Tensor):
9191
torch.ops.prim.device.default, # type: ignore[has-type]
9292
]
9393

94-
# These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
95-
# TODO (tmanlaibaatar) make it a tag
96-
maybe_aliasing_or_mutating_ops = [
97-
torch.ops.aten.dropout.default, # type: ignore[has-type]
98-
torch.ops.aten.batch_norm.default, # type: ignore[has-type]
99-
torch.ops.aten.native_batch_norm.default, # type: ignore[has-type]
100-
torch.ops.aten._batch_norm_impl_index.default, # type: ignore[has-type]
101-
torch.ops.aten.cudnn_batch_norm.default, # type: ignore[has-type]
102-
torch.ops.aten.miopen_batch_norm.default, # type: ignore[has-type]
103-
torch.ops.aten.atleast_1d.default, # type: ignore[has-type]
104-
torch.ops.aten.atleast_2d.default, # type: ignore[has-type]
105-
torch.ops.aten.atleast_3d.default, # type: ignore[has-type]
106-
torch.ops.aten.cartesian_prod.default, # type: ignore[has-type]
107-
torch.ops.aten.conj_physical.default, # type: ignore[has-type]
108-
torch.ops.aten.alpha_dropout.default, # type: ignore[has-type]
109-
torch.ops.aten.feature_dropout.default, # type: ignore[has-type]
110-
torch.ops.aten.feature_alpha_dropout.default, # type: ignore[has-type]
111-
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
112-
]
113-
11494
# Used by auto_functionalize to determine base of tensors during inference mode.
11595
_inference_mode_base: Optional["FunctionalTensor"] = None
11696

@@ -410,7 +390,9 @@ def _can_decompose(func):
410390
return False
411391

412392
# We unconditionally decompose ops that are maybe aliasing or mutating ops
413-
if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
393+
from torch._decomp import _should_decompose_because_unsafe_op
394+
395+
if _should_decompose_because_unsafe_op(func):
414396
return True
415397

416398
# (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,

torch/export/_trace.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,6 @@ def _produce_guards_callback(gm):
17941794
)
17951795

17961796

1797-
# TODO (tmanlaibaatar) We need to preserve aten.to here somehow
17981797
@_log_export_wrapper
17991798
@_disable_prexisiting_fake_mode
18001799
def _export_for_training(

0 commit comments

Comments (0)