Preserve format tag 'any' in tensor::desc::to_grouped #291
yanbing-j merged 2 commits into intel:ideep_pytorch
Conversation
jgong5 left a comment:
Is it better to fix the `to_grouped` function implementation instead? Shouldn't the format `any` be preserved in it?
Hi @jgong5. Thanks for your review. Unfortunately, the original format tag cannot be obtained due to a limitation of oneDNN.
I saw this in the oneDNN API:

```cpp
/// Memory format kind
enum class format_kind {
    /// Undefined memory format kind, used for empty memory descriptors.
    undef = dnnl_format_kind_undef,
    /// A special format kind that indicates that the actual format will be
    /// selected by a primitive automatically.
    any = dnnl_format_kind_any,
    /// A tensor in a generic format described by the stride and blocking
    /// values in each dimension.
    blocked = dnnl_blocked,
#ifdef DNNL_EXPERIMENTAL_SPARSE
    /// Format kind for sparse tensors.
    sparse = dnnl_format_kind_sparse,
#endif
    /// A special format kind that indicates that tensor format is opaque.
    opaque = dnnl_format_kind_opaque,
};
```

It does not return the actual format tag. We can get …
We only need to take care of …

In this case, yes. But if we want …

I don't think it makes sense to ask …

Did you mean that inside …

Yes.

Ok. Thanks. I have updated the PR to preserve format tag `any`.
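The idea behind the fix can be sketched as follows. This is a minimal, dependency-free Python sketch, not ideep's actual C++ code; the `Desc` class and the tag names are hypothetical stand-ins for the oneDNN descriptor types. The point is only that a desc whose format is `any` must stay `any` after its dims are regrouped, instead of being handed a plain layout.

```python
# Hypothetical sketch of the to_grouped fix; names are illustrative only.
from dataclasses import dataclass

@dataclass
class Desc:
    dims: tuple    # e.g. (oc, ic, kh, kw) for a conv weight
    format: str    # 'any' or a concrete tag such as 'abcd'/'abcde'

def to_grouped(d: Desc, groups: int) -> Desc:
    # split the output-channel dim: (oc, ic, ...) -> (g, oc/g, ic, ...)
    gdims = (groups, d.dims[0] // groups) + d.dims[1:]
    # the fix: preserve 'any' so the primitive can still choose the
    # layout; only concrete layouts get a concrete grouped tag
    fmt = 'any' if d.format == 'any' else 'abcde'
    return Desc(gdims, fmt)
```

Before the fix, the function behaved like the `else` branch unconditionally, so a weight desc created with format `any` came back with a plain tag and primitive-desc queries dispatched to the `ref:any` path.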
Hi @yanbing-j Could you please help merge this PR? Thanks.

Please port to other branches if needed. @Xia-Weiwen
* Convert weight desc to format any for conv backward_data and deconv backward_weights for group > 1
* Preserve format tag 'any' in tensor::desc::to_grouped
Description
This is a bug fix. The weight desc format should be `any` when querying the primitive desc. For convolution/deconvolution with group > 1, `tensor::desc::to_grouped` is called to create a grouped desc for the weight. Previously, `to_grouped` did not preserve the original format tag but returned a new desc with a plain format tag. So, in the following cases, where the grouped weight desc is not converted to format `any` explicitly, the weight format used for computation is wrong and the oneDNN `ref:any` path is taken. This PR fixes the issue by preserving format tag `any` in `to_grouped`.

Test results
Convolution with group > 1 backward_data
test code:
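The PR's actual test code is not shown here; the following is a hypothetical PyTorch reproduction inferred from the oneDNN verbose line below (g960, mb96, ic960, oc960, ih7, kh3, sh1, ph1, channels-last activations):

```python
# Hypothetical repro of the depthwise-conv backward case; shapes are
# taken from the onednn_verbose line, everything else is an assumption.
import torch

x = torch.randn(96, 960, 7, 7).to(
    memory_format=torch.channels_last).requires_grad_()
conv = torch.nn.Conv2d(960, 960, kernel_size=3, stride=1, padding=1,
                       groups=960).to(memory_format=torch.channels_last)
y = conv(x)
# backward() runs convolution backward_data (grad of x) and
# backward_weights (grad of conv.weight) through oneDNN
y.sum().backward()
```

Running this with `ONEDNN_VERBOSE=1` shows which backward_data kernel is dispatched, i.e. `ref:any` before the fix versus a jit kernel after it.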
before
FP32
onednn_verbose,primitive,exec,cpu,convolution,ref:any,backward_data,src_f32::blocked:acdb::f0 wei_f32::blocked:abcde::f0 bia_undef::undef::: dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,40.8711
BF16
onednn_verbose,primitive,exec,cpu,convolution,ref:any,backward_data,src_bf16::blocked:acdb::f0 wei_bf16::blocked:abcde::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,68.365
after
FP32
onednn_verbose,primitive,exec,cpu,convolution,jit_dw:avx512_core,backward_data,src_f32::blocked:acdb::f0 wei_f32:a:blocked:Abcde16a::f0 bia_undef::undef::: dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.610107
BF16
onednn_verbose,primitive,exec,cpu,convolution,jit_dw:avx512_core_bf16,backward_data,src_bf16::blocked:acdb::f0 wei_bf16:a:blocked:Abcde16a::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.689941
Deconvolution with group > 1 backward_weights
test code:
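As above, the actual test code is not shown; this is a hypothetical PyTorch reproduction of the grouped-deconvolution case inferred from the verbose line below (same g960/mb96/ic960/oc960/7x7/kh3 shapes):

```python
# Hypothetical repro of the grouped deconv backward_weights case;
# shapes come from the onednn_verbose line, the rest is an assumption.
import torch

x = torch.randn(96, 960, 7, 7).to(
    memory_format=torch.channels_last).requires_grad_()
deconv = torch.nn.ConvTranspose2d(
    960, 960, kernel_size=3, stride=1, padding=1,
    groups=960).to(memory_format=torch.channels_last)
y = deconv(x)  # output stays 96x960x7x7: (7-1)*1 - 2*1 + 3 = 7
# backward() exercises deconvolution backward_weights through oneDNN
y.sum().backward()
```

With `ONEDNN_VERBOSE=1`, the backward_weights line shows `conv:any+ref:any` before the fix and a `conv:any+jit_dw` kernel after it.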
before
FP32
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+ref:any,backward_weights,src_f32::blocked:acdb::f0 wei_f32::blocked:abcde::f0 bia_f32:a:blocked:a::f0 dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,3.62012
BF16
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+ref:any,backward_weights,src_bf16::blocked:acdb::f0 wei_bf16::blocked:abcde::f0 bia_bf16:a:blocked:a::f0 dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,7.23511
after
FP32
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+jit_dw:avx512_core,backward_weights,src_f32::blocked:acdb::f0 wei_f32:a:blocked:Abcde16a::f0 bia_f32:a:blocked:a::f0 dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.391113
BF16
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+jit_dw:avx512_core_bf16,backward_weights,src_bf16::blocked:acdb::f0 wei_bf16:a:blocked:Abcde16a::f0 bia_bf16:a:blocked:a::f0 dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.580078
@jgong5 @yanbing-j Please review. Thanks.