C++ API parity: MultiheadAttention #27309
pbelevich wants to merge 65 commits into gh/pbelevich/30/base
Differential Revision: [D17766736](https://our.internmc.facebook.com/intern/diff/D17766736) [ghstack-poisoned]
CircleCI build failures summary (as of commit d1aadb6), generated by Dr. CI: 3 upstream failures were recognized by patterns. These builds matched patterns but were probably caused by upstream breakages.
TORCH_ARG(bool, add_zero_attn) = false;

/// total number of features in key. Default: c10::nullopt.
TORCH_ARG(int64_t, kdim);
Where is it declared as an optional variable? I wonder whether we should change it to:
TORCH_ARG(c10::optional<int64_t>, kdim) = c10::nullopt;
TORCH_ARG(int64_t, kdim);

/// total number of features in key. Default: c10::nullopt.
TORCH_ARG(int64_t, vdim);
Same as above:
TORCH_ARG(c10::optional<int64_t>, vdim) = c10::nullopt;
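To make the suggestion concrete, here is a minimal sketch of what optional `kdim` and `vdim` fields could look like. It is not the code from this PR: `AttentionOptionsSketch` is a hypothetical stand-in for the real options struct, the accessors are hand-written instead of being generated by `TORCH_ARG`, and `std::optional` replaces `c10::optional` so the snippet compiles without LibTorch (C++17 assumed).

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the options struct discussed above.
// std::optional is used instead of c10::optional (assumption) so that the
// sketch builds without LibTorch.
struct AttentionOptionsSketch {
  explicit AttentionOptionsSketch(int64_t embed_dim) : embed_dim_(embed_dim) {}

  int64_t embed_dim() const { return embed_dim_; }

  // Setters return *this so calls can be chained, similar in spirit to the
  // accessors that TORCH_ARG generates.
  AttentionOptionsSketch& kdim(std::optional<int64_t> value) {
    kdim_ = value;
    return *this;
  }
  const std::optional<int64_t>& kdim() const { return kdim_; }

  AttentionOptionsSketch& vdim(std::optional<int64_t> value) {
    vdim_ = value;
    return *this;
  }
  const std::optional<int64_t>& vdim() const { return vdim_; }

 private:
  int64_t embed_dim_;
  // Both default to "unset", matching the documented default of nullopt.
  std::optional<int64_t> kdim_ = std::nullopt;
  std::optional<int64_t> vdim_ = std::nullopt;
};
```

With this shape, a caller who never touches `kdim` or `vdim` gets an empty optional, so the documented default and the declared type agree.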
}

void MultiheadAttentionImpl::reset() {
  _qkv_same_embed_dim = options.kdim() == options.embed_dim() &&
After making kdim and vdim optional in options/activation.h, we can add code corresponding to these lines in activation.py:
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
if (options.kdim() == c10::nullopt) { options.kdim(options.embed_dim()); }
if (options.vdim() == c10::nullopt) { options.vdim(options.embed_dim()); }
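The fallback described in this comment (use `embed_dim` whenever `kdim` or `vdim` is unset, as the quoted Python lines do) can be sketched on its own. `ResolvedDims` and `resolve_dims` are hypothetical names introduced only for illustration, and `std::optional` again stands in for `c10::optional`. The `qkv_same_embed_dim` flag mirrors the comparison that begins in the `reset()` excerpt above; its second half is assumed to be the matching check on `vdim`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical result type: the dimensions after defaults are applied.
struct ResolvedDims {
  int64_t kdim;
  int64_t vdim;
  bool qkv_same_embed_dim;
};

// Hypothetical helper: an unset kdim/vdim falls back to embed_dim, mirroring
// `self.kdim = kdim if kdim is not None else embed_dim` in Python.
inline ResolvedDims resolve_dims(int64_t embed_dim,
                                 std::optional<int64_t> kdim,
                                 std::optional<int64_t> vdim) {
  const int64_t k = kdim.value_or(embed_dim);
  const int64_t v = vdim.value_or(embed_dim);
  // True only when query, key, and value all share the embedding size.
  return ResolvedDims{k, v, k == embed_dim && v == embed_dim};
}
```

Note that the fallback fires when the optional is empty, so the check has to be "equals nullopt" (or `!has_value()`), not the opposite.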
@pbelevich merged this pull request in 47766e6.
Stack from ghstack:
Differential Revision: D17766736