Make representation computation branchless in TransformerEncoderBase #4818
alexeib merged 2 commits into facebookresearch:main
Conversation
cc @suo

@dianaml0 do you mind taking a look at this?

cc @alexeib
Isn't has_pads a boolean? A 'bool' object doesn't have a 'type_as' attribute!
has_pads is a Tensor holding a bool scalar, so it can be converted with the type_as method on Tensor. That's my understanding.
Actually, if src_token's device is "xla", it would be a plain bool, which is rare but possible.
Let me update the PR to handle that.
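For context, the likely reason has_pads can end up as a plain bool is Python's short-circuiting `or`: if the left operand is truthy, `or` returns that operand itself rather than evaluating the tensor expression. This is a minimal illustrative sketch, not the exact fairseq code; the names on_xla and mask stand in for the device check and padding mask:

```python
import torch

mask = torch.zeros(2, 3, dtype=torch.bool)  # hypothetical padding mask
on_xla = False  # stands in for: src_tokens.device.type == "xla"

# When the left side is False, `or` evaluates and returns the Tensor.
has_pads = on_xla or mask.any()
print(type(has_pads))  # <class 'torch.Tensor'>

# When the left side is True, `or` short-circuits and returns the bool itself.
has_pads_on_xla = True or mask.any()
print(type(has_pads_on_xla))  # <class 'bool'>
```

So on an "xla" device the device check is True and has_pads never becomes a Tensor, which is exactly the rare case being discussed.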
@moslehpour updated. mind looking again?
Great. This should work!
Just one last comment: can you confirm that TorchScript supports type casting with type_as?
@moslehpour sure, I just tried that:

    @torch.jit.script
    def foo(x, y):
        return x.type_as(y)

    a = foo(torch.tensor(True), torch.ones(3, 2))
    print(a)

and it seems TorchScript gives us the correct result:

    tensor(1.)
Perfect. Thanks for checking it.
Summary: We want to make the computation branchless here because fairseq code may be exported and traced for deployment, and tracing can break the correctness of a captured program if it depends on input data. This diff rewrites the code to remove one branch so that the tracer can proceed and the correct semantics of the model are preserved.

Test Plan: CI
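The kind of rewrite described above can be sketched roughly as follows. This is a minimal illustration under assumed shapes, not the exact diff; encoder_padding_mask and has_pads mirror the names discussed in the PR:

```python
import torch

x = torch.randn(2, 4, 8)  # assumed (batch, seq_len, embed_dim)
encoder_padding_mask = torch.zeros(2, 4, dtype=torch.bool)
encoder_padding_mask[0, 2:] = True  # pad positions in sequence 0
has_pads = encoder_padding_mask.any()

# Branchy version: the data-dependent `if` is what a tracer freezes,
# baking in whichever side ran on the example input.
if has_pads:
    x_branchy = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
else:
    x_branchy = x

# Branchless version: multiplying the mask by has_pads makes the masking
# a no-op when there are no pads, so the `if` disappears entirely.
x_branchless = x * (
    1 - (encoder_padding_mask.unsqueeze(-1) * has_pads).type_as(x)
)

assert torch.equal(x_branchy, x_branchless)
```

Both forms compute the same result, but only the branchless one yields a single graph that is valid for inputs with and without padding.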
thanks!

The tests are failing now because of this change: "Variable 'has_pads' is annotated with type Tensor but is being assigned to a value of type bool"
@alexeib I'm looking into this issue, and I'll try to send in a fix real quick. |
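One way such an annotation error can be resolved is to build has_pads from tensor operations only, so TorchScript sees a Tensor on every path. This is a hedged sketch; the merged fix (#4847) may differ in detail, and the pad index below is an assumption:

```python
import torch

src_tokens = torch.tensor([[5, 6, 1]])  # hypothetical batch; assume 1 is the pad index
encoder_padding_mask = src_tokens.eq(1)

# Tensor `|` never short-circuits to a Python bool the way `or` can,
# so the `has_pads: Tensor` annotation holds on every device type.
has_pads: torch.Tensor = (
    torch.tensor(src_tokens.device.type == "xla") | encoder_padding_mask.any()
)
print(has_pads)  # tensor(True)
```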
* fix imports referencing moved metrics.py file
* Make representation computation branchless in TransformerEncoderBase (#4818)
* Fix Torchscript typing in transformer_encoder.py (#4847)
* Add Generative Spoken Dialogue Language Modeling (#4879)
* Update deprecated torch.qr in glow.py example (#4685): torch.qr has been deprecated for a long time and is being removed by pytorch/pytorch#70989; this makes the example compatible with new and old PyTorch versions.
* Emotion Conversion Paper Open Source (#4895)
* data2vec v2.0 (#4903)
* remove missing config entries when loading task from checkpoint (#4905)
* make apex optional (#4906)
* Add file to generate manifests for stop dataset (#4891)
* Update STOP dataset README to include proper link (#4892)
* Update README.md (#4893)
* using foreach to reduce kernel (#4904): set reproducibility to a looser threshold; revert optimizer
* Update README.md to add data2vec blog post (#4913)
* Update config to fix circleci failure (#4949): https://app.circleci.com/pipelines/github/fairinternal/fairseq-py/12635/workflows/3befbae2-79c4-458d-9fc4-aad4484183b4/jobs/26767
* Generative Spoken Dialogue Language Modeling Paper Open Source (#4957)
* wav2vec2_laser (#4968)
* ASR BLEU tool copied from ust branch into main (#4914)
* Add transcript option for asr-bleu (#4981)

Co-authored-by: zhxchen17 <zhxchen17@outlook.com>
Co-authored-by: zhxchen17 <zhxchen17@fb.com>
Co-authored-by: Nguyen Tu Anh <nguyentuanh208@gmail.com>
Co-authored-by: Sergii Dymchenko <kit1980@gmail.com>
Co-authored-by: Felix Kreuk <felixkreuk@gmail.com>
Co-authored-by: Alexei Baevski <alexei.b@gmail.com>
Co-authored-by: padentomasello <pdtomasello@gmail.com>
Co-authored-by: Junteng Jia <juntengjia@hotmail.com>
Co-authored-by: juntengjia <juntengjia@fb.com>
Co-authored-by: arbabu123 <arbabu@fb.com>
Co-authored-by: Arun Babu <arbabu@fb.com>
Co-authored-by: Wei-Ning Hsu <wnhsu@csail.mit.edu>
Co-authored-by: dianaml0 <82468439+dianaml0@users.noreply.github.com>
Co-authored-by: Pierre Andrews <mortimer@fb.com>
Co-authored-by: Ilia Kulikov <kulikov@cs.nyu.edu>
Co-authored-by: Xutai Ma <xutaima@gmail.com>