[ONNX] Add onnx::LayerNorm support for version 17 #84293
titaiwangms wants to merge 13 commits into gh/AllenTiTaiWang/10/base
Conversation
✅ No Failures (0 Pending) as of commit 7aa8ac9 (more details on the Dr. CI page). Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI.
I would add a test.
We already have it: pytorch/test/onnx/test_pytorch_onnx_onnxruntime.py, lines 3676 to 3679 in 641c395. And this PR should pass it.
torch/onnx/symbolic_opset17.py (outdated):

    @symbolic_helper.parse_args("v", "is", "v", "v", "f", "i")
    def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
One more note: please annotate the types and decorate the function with _beartype.beartype.
Per offline meeting: more _beartype.beartype use cases will be introduced in another PR from Justin. I have added type annotations and beartype on the shared layer norm function.
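For context, the symbolic function under discussion maps aten::layer_norm onto a single ONNX LayerNormalization-17 node. The following is a minimal sketch of that shape, using a hypothetical FakeGraph stand-in for the exporter's graph context (the real function receives torch.onnx's graph object and goes through the parse_args machinery):

```python
# FakeGraph is a hypothetical stand-in for torch.onnx's graph context,
# used here only to show the structure of the symbolic function.
class FakeGraph:
    def __init__(self):
        self.nodes = []

    def op(self, kind, *inputs, **attrs):
        # The real g.op(...) emits an ONNX node into the exported graph.
        self.nodes.append((kind, inputs, attrs))
        return f"{kind}_out"


def layer_norm_symbolic(g, input, normalized_shape, weight, bias, eps):
    # LayerNormalization-17 normalizes dims [axis, rank-1]; PyTorch
    # normalizes the last len(normalized_shape) dims, hence the
    # negative axis counted from the back.
    return g.op(
        "LayerNormalization",
        input, weight, bias,
        axis_i=-len(normalized_shape),
        epsilon_f=eps,
    )


g = FakeGraph()
layer_norm_symbolic(g, "X", [10, 10, 10], "W", "B", 1e-5)
kind, inputs, attrs = g.nodes[0]
assert kind == "LayerNormalization"
assert attrs["axis_i"] == -3
```

This is a sketch under the assumptions above, not the PR's exact code; the actual implementation also handles the unused cudnn_enable argument and the shared opset-9 fallback.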
torch/onnx/symbolic_opset17.py (outdated):

    @symbolic_helper.parse_args("v", "is", "v", "v", "f", "i")
    def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
cudnn_enable is a boolean. Since it is not used, we can annotate it as "none" in parse_args.
Should we use "none", or keep it as no conversion ("v") in this case?
If "none" is a common usage, I will add it into doc of parse args.
torch/onnx/symbolic_opset17.py (outdated):

    weight,
    bias,
    epsilon_f=eps,
    axis_i=len(normalized_shape),
Can you write a brief inline comment explaining what axis is and why len(normalized_shape) is the same thing?
Good question, I think particularly in this test case, the value happens to match...
I was referencing this: pytorch/torch/onnx/symbolic_opset9.py, line 2309 in 14093b5. And according to https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LayerNormalization-17:

    axis : int (default is -1)
    The first normalization dimension. If rank(X) is r, axis' allowed range is [-r, r]. Negative value means counting dimensions from the back.
I tried different params for layer_norm and input, and it works so far. @BowenBao, is there a test case you have in mind that might fail?
What are the combinations you tried?

I suspect it should be axis_i = -len(normalized_shape). However, for this current test case, both axis_i = 2 and axis_i = -2 map to the same dimension. So I'd recommend varying the rank:

    model = torch.nn.LayerNorm([10, 10, 10])
    x = torch.randn(20, 5, 10, 10, 10)
    self.run_test(model, x)
    # As layer_norm works on the last D dimensions, please keep
    # this test case at least three-dimensional to prevent the
    # situation of axis=2 mapping to the same axis as axis=-2
    model2 = torch.nn.LayerNorm([10, 10, 10])
Any reason this is named model2? For clarity we can just do layer_norm_model = ..., or just model.
Ahh, I was going to put in two tests and then only left one. I will clean it up, and figure out the CI failure with beartype.
        return res

    @_beartype.beartype
You may remove the beartype check for this one here. I can handle it with #84091
@AllenTiTaiWang you can get rid of the build CI issue by rebasing with the latest master.
torch/onnx/symbolic_opset17.py (outdated):

    ):
        # normalized_shape: input shape from an expected input of size
        # axis: The first normalization dimension.
        # layer_norm normalizes on the last D dimension,

Suggested change: "# layer_norm normalizes on the last D dimension," -> "# layer_norm normalizes on the last D dimensions,"
torch/onnx/symbolic_opset17.py (outdated):

        # normalized_shape: input shape from an expected input of size
        # axis: The first normalization dimension.
        # layer_norm normalizes on the last D dimension,
        # which D is the size of normalized_shape

Suggested change: "# which D is the size of normalized_shape" -> "# where D is the size of normalized_shape"
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
Hey @AllenTiTaiWang. |
Summary: Pull Request resolved: #84293
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7c4c7dafbdf2c41ccd9042f1db4f9f9f01a42f00
Reviewed By: mehtanirav
Differential Revision: D39277708
fbshipit-source-id: 3101337ef4d3ff564c5df604afcae4fd86e1a76d