
Preserve format tag 'any' in tensor::desc::to_grouped #291

Merged
yanbing-j merged 2 commits into intel:ideep_pytorch from Xia-Weiwen:w_desc_to_format_any
Mar 16, 2024

Conversation

Contributor

@Xia-Weiwen Xia-Weiwen commented Mar 14, 2024

Description

This is a bug fix. The weight desc format should be 'any' when querying the primitive desc. For convolution/deconvolution with group > 1, tensor::desc::to_grouped is called to create a grouped desc for the weight.
Previously, to_grouped did not preserve the original format tag but returned a new desc with a plain format tag. So, in the following cases, where the grouped weight desc was not converted to format 'any' explicitly, the weight format chosen for computation was wrong and the slow oneDNN ref:any path was used:

  • Convolution backward_data when group > 1
  • Deconvolution backward_weights when group > 1

This PR fixes the issue by preserving format tag any in to_grouped.
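The logic of the fix can be sketched in pure Python. This is an illustrative model only, not the actual ideep C++ code; the names `Desc`, `to_grouped`, `FORMAT_ANY`, and `FORMAT_PLAIN` are hypothetical stand-ins for the oneDNN memory descriptor machinery:

```python
# Hypothetical, simplified model of the fix; the real ideep code operates
# on oneDNN memory descriptors in C++. All names here are illustrative.
from dataclasses import dataclass
from typing import Tuple

FORMAT_ANY = "any"      # let the primitive pick the layout
FORMAT_PLAIN = "plain"  # a concrete, dense layout

@dataclass
class Desc:
    dims: Tuple[int, ...]
    format_tag: str

def to_grouped(desc: Desc, groups: int) -> Desc:
    """Split the leading (output-channel) dim into (groups, oc // groups).

    Before the fix, the returned desc always carried a plain format tag,
    so a weight desc created with 'any' silently lost that tag and oneDNN
    fell back to the slow ref:any implementation.
    """
    oc, *rest = desc.dims
    assert oc % groups == 0, "output channels must be divisible by groups"
    grouped_dims = (groups, oc // groups, *rest)
    # The fix: keep 'any' if the source desc had it.
    tag = FORMAT_ANY if desc.format_tag == FORMAT_ANY else FORMAT_PLAIN
    return Desc(grouped_dims, tag)

# Depthwise conv weight (oc=960, ic/groups=1, kh=kw=3), queried with 'any'.
w = Desc((960, 1, 3, 3), FORMAT_ANY)
gw = to_grouped(w, groups=960)
print(gw.dims, gw.format_tag)  # (960, 1, 1, 3, 3) any
```

With the tag preserved, the primitive-desc query is free to choose a fast blocked layout (such as the Abcde16a seen in the verbose logs below) instead of being pinned to a plain one.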

Test results

Convolution with group > 1 backward_data

test code:

import torch

use_bf16 = False  # set to True for the BF16 case (use_bf16 was undefined in the original snippet)
m = torch.nn.Conv2d(in_channels=960, out_channels=960, kernel_size=3, stride=1, padding=1, groups=960)
input = torch.randn(96, 960, 7, 7, requires_grad=True)
m = m.to(memory_format=torch.channels_last)
input = input.to(memory_format=torch.channels_last)  # Tensor.to is not in-place; assign the result
if use_bf16:
    m = m.to(torch.bfloat16)
    input = input.to(torch.bfloat16)
output = m(input)
loss = output.mean()
loss.backward()

before
FP32
onednn_verbose,primitive,exec,cpu,convolution,ref:any,backward_data,src_f32::blocked:acdb::f0 wei_f32::blocked:abcde::f0 bia_undef::undef::: dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,40.8711
BF16
onednn_verbose,primitive,exec,cpu,convolution,ref:any,backward_data,src_bf16::blocked:acdb::f0 wei_bf16::blocked:abcde::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,68.365

after
FP32
onednn_verbose,primitive,exec,cpu,convolution,jit_dw:avx512_core,backward_data,src_f32::blocked:acdb::f0 wei_f32:a:blocked:Abcde16a::f0 bia_undef::undef::: dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.610107

BF16
onednn_verbose,primitive,exec,cpu,convolution,jit_dw:avx512_core_bf16,backward_data,src_bf16::blocked:acdb::f0 wei_bf16:a:blocked:Abcde16a::f0 bia_undef::undef::: dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:convolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.689941
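To compare the verbose lines above at a glance: the comma-separated fields follow oneDNN's ONEDNN_VERBOSE format, and the two fields that matter here are the implementation name and the trailing execution time in ms. A small helper (not part of ideep; `parse_verbose` is a name made up for this sketch) can extract them:

```python
def parse_verbose(line: str):
    """Pull (impl, prop_kind, time_ms) out of a oneDNN verbose exec line."""
    fields = line.split(",")
    impl = fields[5]             # e.g. 'ref:any' or 'jit_dw:avx512_core'
    prop_kind = fields[6]        # e.g. 'backward_data'
    time_ms = float(fields[-1])  # execution time is the last field
    return impl, prop_kind, time_ms

# Abbreviated copies of the FP32 lines above (memory descriptors elided).
before = "onednn_verbose,primitive,exec,cpu,convolution,ref:any,backward_data,desc,attrs,alg,prb,40.8711"
after = "onednn_verbose,primitive,exec,cpu,convolution,jit_dw:avx512_core,backward_data,desc,attrs,alg,prb,0.610107"

b_impl, _, b_ms = parse_verbose(before)
a_impl, _, a_ms = parse_verbose(after)
print(b_impl, a_impl, round(b_ms / a_ms))  # ref:any jit_dw:avx512_core 67
```

For the FP32 convolution backward_data case this gives roughly a 67x speedup once the jit_dw implementation is selected instead of ref:any.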

Deconvolution with group > 1 backward_weights

test code:

import torch

use_bf16 = False  # set to True for the BF16 case (use_bf16 was undefined in the original snippet)
m = torch.nn.ConvTranspose2d(in_channels=960, out_channels=960, kernel_size=3, stride=1, padding=1, groups=960)
input = torch.randn(96, 960, 7, 7, requires_grad=True)
m = m.to(memory_format=torch.channels_last)
input = input.to(memory_format=torch.channels_last)  # Tensor.to is not in-place; assign the result
if use_bf16:
    m = m.to(torch.bfloat16)
    input = input.to(torch.bfloat16)
output = m(input)
loss = output.mean()
loss.backward()

before
FP32
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+ref:any,backward_weights,src_f32::blocked:acdb::f0 wei_f32::blocked:abcde::f0 bia_f32:a:blocked:a::f0 dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,3.62012

BF16
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+ref:any,backward_weights,src_bf16::blocked:acdb::f0 wei_bf16::blocked:abcde::f0 bia_bf16:a:blocked:a::f0 dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,7.23511

after
FP32
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+jit_dw:avx512_core,backward_weights,src_f32::blocked:acdb::f0 wei_f32:a:blocked:Abcde16a::f0 bia_f32:a:blocked:a::f0 dst_f32::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.391113
BF16
onednn_verbose,primitive,exec,cpu,deconvolution,conv:any+jit_dw:avx512_core_bf16,backward_weights,src_bf16::blocked:acdb::f0 wei_bf16:a:blocked:Abcde16a::f0 bia_bf16:a:blocked:a::f0 dst_bf16::blocked:acdb::f0,attr-scratchpad:user ,alg:deconvolution_direct,g960mb96_ic960oc960_ih7oh7kh3sh1dh0ph1_iw7ow7kw3sw1dw0pw1,0.580078


@jgong5 @yanbing-j Please review. Thanks.


@jgong5 jgong5 left a comment


Is it better to fix the to_grouped function implementation instead? Shouldn't the format 'any' be preserved in it?

@Xia-Weiwen
Contributor Author

> Is it better to fix the to_grouped function implementation instead? Shouldn't the format 'any' be preserved in it?

Hi @jgong5. Thanks for your review. Unfortunately, the original format tag cannot be obtained, due to a limitation of the oneDNN memory::desc API, so we don't know the original format. It is therefore left to the caller to preserve the format. This is also done in other cases, such as conv/deconv forward and conv backward_weights; the cases in this PR had simply been missed before.

@Xia-Weiwen Xia-Weiwen requested a review from jgong5 March 15, 2024 00:49
@jgong5

jgong5 commented Mar 15, 2024

> Is it better to fix the to_grouped function implementation instead? Shouldn't the format 'any' be preserved in it?

> Hi @jgong5. Thanks for your review. Unfortunately, the original format tag cannot be obtained, due to a limitation of the oneDNN memory::desc API, so we don't know the original format. It is therefore left to the caller to preserve the format. This is also done in other cases, such as conv/deconv forward and conv backward_weights; the cases in this PR had simply been missed before.

I saw the get_format_kind method of the memory descriptor. Is it not usable?

@Xia-Weiwen
Contributor Author

> I saw the get_format_kind method of the memory descriptor. Is it not usable?

get_format_kind returns the following enum class:

    /// Memory format kind
    enum class format_kind {
        /// Undefined memory format kind, used for empty memory descriptors.
        undef = dnnl_format_kind_undef,
        /// A special format kind that indicates that the actual format will be
        /// selected by a primitive automatically.
        any = dnnl_format_kind_any,
        /// A tensor in a generic format described by the stride and blocking
        /// values in each dimension.
        blocked = dnnl_blocked,
#ifdef DNNL_EXPERIMENTAL_SPARSE
        /// Format kind for sparse tensors.
        sparse = dnnl_format_kind_sparse,
#endif
        /// A special format kind that indicates that tensor format is opaque.
        opaque = dnnl_format_kind_opaque,
    };

It does not return the actual format tag: we can detect 'any', but no other specific format tag. So, in most cases, we cannot preserve the format tag inside to_grouped.
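To make the limitation concrete, here is a plain-Python illustration (a handwritten mapping, not the oneDNN API): every concrete layout reports the same format kind, blocked, so the kind alone cannot distinguish nchw from nhwc, or from a blocked tag:

```python
# Illustrative mapping from example format tags to the format_kind values
# that get_format_kind() would report for them; this dict is a sketch, not
# an actual oneDNN API call.
format_kind_of_tag = {
    "any":      "any",      # the only tag recoverable from the kind
    "nchw":     "blocked",  # plain layout
    "nhwc":     "blocked",  # channels-last layout
    "Abcde16a": "blocked",  # blocked layout picked by a primitive
}

# 'any' round-trips; every concrete tag collapses to 'blocked'.
concrete = sorted(t for t, k in format_kind_of_tag.items() if k == "blocked")
print(format_kind_of_tag["any"], concrete)  # any ['Abcde16a', 'nchw', 'nhwc']
```

This is why the discussion below settles on preserving only 'any' inside to_grouped and returning a plain format for everything else.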

@jgong5

jgong5 commented Mar 15, 2024

> It does not return the actual format tag: we can detect 'any', but no other specific format tag. So, in most cases, we cannot preserve the format tag inside to_grouped.

We only need to take care of 'any' here, not other formats, right?

@Xia-Weiwen
Contributor Author

> It does not return the actual format tag: we can detect 'any', but no other specific format tag. So, in most cases, we cannot preserve the format tag inside to_grouped.

> We only need to take care of 'any' here, not other formats, right?

In this case, yes. But if we want to_grouped to preserve the format, I think it has to work in all cases, not just for 'any', right?

@jgong5

jgong5 commented Mar 15, 2024

> It does not return the actual format tag: we can detect 'any', but no other specific format tag. So, in most cases, we cannot preserve the format tag inside to_grouped.

> We only need to take care of 'any' here, not other formats, right?

> In this case, yes. But if we want to_grouped to preserve the format, I think it has to work in all cases, not just for 'any', right?

I don't think it makes sense to ask to_grouped to keep the format, right? The format with groups would be different if it were specified. But preserving 'any' does make sense?

@Xia-Weiwen
Contributor Author

> In this case, yes. But if we want to_grouped to preserve the format, I think it has to work in all cases, not just for 'any', right?

> I don't think it makes sense to ask to_grouped to keep the format, right? The format with groups would be different if it were specified. But preserving 'any' does make sense?

Did you mean that, inside to_grouped, we preserve 'any' for 'any' and return a plain format for other formats?

@jgong5

jgong5 commented Mar 15, 2024

> In this case, yes. But if we want to_grouped to preserve the format, I think it has to work in all cases, not just for 'any', right?

> I don't think it makes sense to ask to_grouped to keep the format, right? The format with groups would be different if it were specified. But preserving 'any' does make sense?

> Did you mean that, inside to_grouped, we preserve 'any' for 'any' and return a plain format for other formats?

Yes.

@Xia-Weiwen Xia-Weiwen changed the title from "Convert weight desct to format any for conv/deconv backward" to "Preserve format tag 'any' in tensor::desc::to_grouped" Mar 15, 2024
@Xia-Weiwen
Contributor Author

> In this case, yes. But if we want to_grouped to preserve the format, I think it has to work in all cases, not just for 'any', right?

> I don't think it makes sense to ask to_grouped to keep the format, right? The format with groups would be different if it were specified. But preserving 'any' does make sense?

> Did you mean that, inside to_grouped, we preserve 'any' for 'any' and return a plain format for other formats?

> Yes.

OK. Thanks. I have updated the PR to preserve format tag 'any'.

@Xia-Weiwen
Contributor Author

Hi @yanbing-j Could you please help merge this PR? Thanks.

@yanbing-j yanbing-j merged commit 9fedc30 into intel:ideep_pytorch Mar 16, 2024
@yanbing-j
Contributor

Please port to other branches if needed. @Xia-Weiwen

Xia-Weiwen added a commit to Xia-Weiwen/ideep that referenced this pull request Mar 16, 2024
* Convert weight desct to format any for conv backward_data and deconv backward_weights for group > 1

* Preserve format tag 'any' in tensor::desc::to_grouped
yanbing-j pushed a commit that referenced this pull request Mar 18, 2024
* Convert weight desct to format any for conv backward_data and deconv backward_weights for group > 1

* Preserve format tag 'any' in tensor::desc::to_grouped