
[ONNX] Add onnx::LayerNorm support for version 17#84293

Closed
titaiwangms wants to merge 13 commits into gh/AllenTiTaiWang/10/base from gh/AllenTiTaiWang/10/head

Conversation

@titaiwangms (Collaborator) commented Aug 30, 2022

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Aug 30, 2022
@facebook-github-bot (Contributor) commented Aug 30, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 7aa8ac9 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@justinchuby (Collaborator)

I would add a test

@justinchuby justinchuby self-assigned this Aug 30, 2022
@titaiwangms (Collaborator, Author) commented Aug 30, 2022

> I would add a test

We already have it.

def test_layer_norm(self):
    model = torch.nn.LayerNorm([10, 10])
    x = torch.randn(20, 5, 10, 10)
    self.run_test(model, x)

And this PR should pass it.
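For reference, a self-contained, hedged version of what that test exercises (shapes taken from the snippet above; the assertions are illustrative and not part of the actual suite):

```python
import torch

# LayerNorm([10, 10]) normalizes each trailing (10, 10) slice of the
# rank-4 input to roughly zero mean and unit variance; the default
# affine parameters (weight=1, bias=0) leave the result unchanged.
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
y = model(x)

assert y.shape == x.shape
# Per-slice statistics over the last two dims: mean ~ 0, biased var ~ 1.
assert torch.allclose(y.mean(dim=(-2, -1)), torch.zeros(20, 5), atol=1e-5)
assert torch.allclose(y.var(dim=(-2, -1), unbiased=False), torch.ones(20, 5), atol=1e-3)
```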



@symbolic_helper.parse_args("v", "is", "v", "v", "f", "i")
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
Collaborator

One more note: please annotate the types and decorate the function with _beartype.beartype.

Collaborator Author

From an offline meeting: more _beartype.beartype use cases will be introduced in another PR from Justin. I have added type annotations and beartype to the shared layer norm function.



@symbolic_helper.parse_args("v", "is", "v", "v", "f", "i")
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
Collaborator

Suggested change
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):

cudnn_enable is a boolean. Since it is not used, we can annotate it as "none" in parse_args.

@titaiwangms (Collaborator, Author) Aug 31, 2022

Should we use "none" or keep it as no conversion (use "v") in this case?

Collaborator Author

If "none" is a common usage, I will add it to the parse_args documentation.

Collaborator Author

Done

weight,
bias,
epsilon_f=eps,
axis_i=len(normalized_shape),
Collaborator

Can you write a brief comment inline explaining what axis is and why len(normalized_shape) is the same thing?

Collaborator

Good question, I think particularly in this test case, the value happens to match...

Collaborator Author

I was referencing this:

axes = [-i for i in range(len(normalized_shape), 0, -1)]

And according to https://github.com/onnx/onnx/blob/main/docs/Changelog.md#LayerNormalization-17

axis : int (default is -1)
The first normalization dimension. If rank(X) is r, axis' allowed range is [-r, r]. Negative value means counting dimensions from the back.

I tried different params for layer_norm and input, and it works so far. @BowenBao is there a test case you have in mind that might fail?

Collaborator

What are the combinations you tried?
I suspect it should be axis_i = -len(normalized_shape). However, for this current test case, both axis_i = 2 and axis_i = -2 map to the same dimension. So I'd recommend trying a different rank:

     model = torch.nn.LayerNorm([10, 10, 10]) 
     x = torch.randn(20, 5, 10, 10, 10) 
     self.run_test(model, x) 
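The rank mismatch the reviewer describes can be checked with a few lines of tensor arithmetic (a hedged sketch; the shapes mirror the suggested test, and the variable names are illustrative):

```python
import torch

# With a rank-5 input and a 3-d normalized_shape, the positive and
# negative axis conventions no longer coincide: the first normalized
# dimension is r - D (counting from the front), which is -D in
# negative indexing, while +D would name a different dimension.
x = torch.randn(20, 5, 10, 10, 10)
normalized_shape = [10, 10, 10]

D = len(normalized_shape)   # 3
r = x.dim()                 # 5

assert r - D == 2           # first normalized dim, from the front
assert r + (-D) == 2        # axis = -D resolves to the same dim
assert D != r - D           # axis = +D (3) would be wrong here
```

This is why the original rank-4 test (where r == 2 * D) could not distinguish axis_i = len(normalized_shape) from axis_i = -len(normalized_shape).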

Collaborator Author

I see

Collaborator Author

Done

titaiwangms added a commit that referenced this pull request Aug 31, 2022
ghstack-source-id: 7d6cd16
Pull Request resolved: #84293
titaiwangms added a commit that referenced this pull request Aug 31, 2022
ghstack-source-id: b1630e9
Pull Request resolved: #84293
# As layer_norm works on the last D dimensions, please keep
# this test case at least three dimensions to prevent
# axis=2 from mapping to the same axis as axis=-2
model2 = torch.nn.LayerNorm([10, 10, 10])
Collaborator

Any reason this is named model2? For clarity we can just do layer_norm_model = ... or just model

Collaborator Author

Ahh, I was going to put in two tests and then kept only one. I will clean it up, and figure out the CI failure with beartype.

Collaborator Author

Done
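The property under discussion, that layer_norm normalizes over the last D = len(normalized_shape) dimensions, can be verified directly against torch.nn.functional.layer_norm (a hedged sketch, not code from this PR):

```python
import torch
import torch.nn.functional as F

# For a rank-5 input and a 3-d normalized_shape, manually normalizing
# over the last three dims (biased variance, same eps) should match
# F.layer_norm with no affine parameters.
x = torch.randn(20, 5, 10, 10, 10)
shape = [10, 10, 10]
eps = 1e-5

ref = F.layer_norm(x, shape, eps=eps)
mean = x.mean(dim=(-3, -2, -1), keepdim=True)
var = x.var(dim=(-3, -2, -1), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

assert torch.allclose(ref, manual, atol=1e-4)
```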

return res


@_beartype.beartype
Collaborator

You may remove the beartype check for this one here. I can handle it with #84091

@BowenBao (Collaborator) commented Sep 1, 2022

@AllenTiTaiWang you can get rid of the build CI issue by rebasing with latest master.

titaiwangms added a commit that referenced this pull request Sep 1, 2022
ghstack-source-id: 2dfc81e
Pull Request resolved: #84293
titaiwangms added a commit that referenced this pull request Sep 3, 2022
ghstack-source-id: 3f12925
Pull Request resolved: #84293
):
# normalized_shape: input shape from an expected input of size
# axis: The first normalization dimension.
# layer_norm normalizes on the last D dimension,
Collaborator

Suggested change
# layer_norm normalizes on the last D dimension,
# layer_norm normalizes on the last D dimensions,

Collaborator Author

Done

# normalized_shape: input shape from an expected input of size
# axis: The first normalization dimension.
# layer_norm normalizes on the last D dimension,
# which D is the size of normalized_shape
Collaborator

Suggested change
# which D is the size of normalized_shape
# where D is the size of normalized_shape

Collaborator Author

Done

titaiwangms added a commit that referenced this pull request Sep 4, 2022
ghstack-source-id: 4155b51
Pull Request resolved: #84293
@titaiwangms (Collaborator, Author)

@pytorchbot merge -g

@titaiwangms titaiwangms added the module: onnx Related to torch.onnx label Sep 4, 2022
@pytorchmergebot (Collaborator)

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered with the green (-g) flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions (Contributor) commented Sep 4, 2022

Hey @AllenTiTaiWang.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2022
Summary:
Pull Request resolved: #84293
Approved by: https://github.com/justinchuby, https://github.com/BowenBao

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7c4c7dafbdf2c41ccd9042f1db4f9f9f01a42f00

Reviewed By: mehtanirav

Differential Revision: D39277708

fbshipit-source-id: 3101337ef4d3ff564c5df604afcae4fd86e1a76d
@facebook-github-bot facebook-github-bot deleted the gh/AllenTiTaiWang/10/head branch September 7, 2022 14:20

Labels

- cla signed
- Merged
- module: onnx (Related to torch.onnx)
- open source
- release notes: onnx (torch.onnx related changes that should show up in the release notes)


6 participants