
[Operator] Add new operator "normalize" that makes a group of layers (layer norm, group norm and instance norm) faster using hidet script#257

Merged
xinli-git merged 14 commits into hidet-org:main from xinli-git:normalize
May 29, 2023

Conversation

Contributor

@xinli-git xinli-git commented May 29, 2023

Refactor the normalize op to use a dedicated task and hidet script implementation

Three main optimizations:

  • Welford reduce: a one-pass algorithm that reads each element from global memory once and writes the result back once. Compared to the previous approach, this reduces global memory traffic by at least 2x.
  • Fused mean/variance calculation and normalization into a single kernel, removing redundant global memory accesses.
  • Two-stage reduce (1024 -> 32 -> 1) for larger reduction sizes.
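For reference, the Welford one-pass update and the shape of the two-stage reduction can be sketched in plain Python (illustrative only, not the hidet script kernel; the function names are hypothetical):

```python
def welford_update(count, mean, m2, value):
    """Fold one value into the running (count, mean, M2) state in a single pass."""
    count += 1
    delta = value - mean
    mean += delta / count
    m2 += delta * (value - mean)  # uses the already-updated mean
    return count, mean, m2

def mean_var(values):
    """One-pass mean and population variance, as needed for normalization."""
    count, mean, m2 = 0, 0.0, 0.0
    for v in values:
        count, mean, m2 = welford_update(count, mean, m2, v)
    return mean, m2 / count

def two_stage_sum(vals, width=32):
    """Shape of the 1024 -> 32 -> 1 reduction: strided partials, then a final pass."""
    partials = [sum(vals[i::width]) for i in range(width)]  # stage 1: many -> width
    return sum(partials)                                    # stage 2: width -> 1

mean, var = mean_var([1.0, 2.0, 3.0, 4.0])  # -> (2.5, 1.25)
```

Because each input element is touched exactly once for both the mean and the variance, the kernel avoids the separate subtract/square/reduce passes shown in the "Before" log below.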

Before

Compiling cuda task subtract(x=float32(1, 128, 768), y=float32(1, 128, 1), z=float32(1, 128, 768))...
Compiling cuda task reduce_avg(x=float32(1, 128, 768), y=float32(1, 128, 1), dims=[2], keep_dim=True, reduce_type='avg', accumulate_dtype='float32')...
Compiling cuda task fused(x=float32(1, 128, 768), y=float32(1, 128, 1), fused_ops='square reduce_avg adds rsqrt', anchor='reduce_avg')...
Compiling cuda task fused(y=float32(768,), x=float32(1, 128, 768), y=float32(1, 128, 1), z=float32(1, 128, 768), fused_ops='mul mul', anchor='mul')...
Compiling cuda task fused(y=float32(768,), y=float32(768,), x=float32(1, 128, 768), y=float32(1, 128, 1), z=float32(1, 128, 768), fused_ops='mul mul add', anchor='add')...

After:

Compiling cuda task normalize_float32(x=float32(1, 128, 768), y=float32(1, 128, 768), dims=[2], accumulate_dtype='float32', epsilon=1e-12f)...

I have a suspicion that this can somehow also work with epilogue fusion enabled, but I probably need to discuss it further with Yaoyao. The use case is that there is usually an affine transformation after the normalization.

I tried the benchmark on bert-base-uncased (resnet is unaffected); there was not a huge improvement in performance, so I think this model is not heavily limited by layernorm. I'll look at some micro-benchmarking soon.

hidet bench --cache-dir ./test_cache --space 2 --dtype float32 --report bert-seq128-f32.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased
...
| model                   | inputs             |   eager |   reduce-overhead |   max-autotune |   hidet(2) |
|-------------------------|--------------------|---------|-------------------|----------------|------------|
| model/bert-base-uncased | f32, bs=1, seq=128 |   2.863 |             2.715 |          2.815 |      2.023 |

and

hidet bench --cache-dir ./test_cache --space 2 --dtype float16 --report bert-seq128-f16.txt --tensor-core nlp --seq-length 128 --models bert-base-uncased
| model                   | inputs             |   eager |   reduce-overhead |   max-autotune |   hidet(2) |
|-------------------------|--------------------|---------|-------------------|----------------|------------|
| model/bert-base-uncased | f16, bs=1, seq=128 |   1.710 |             1.701 |          1.097 |      1.521 |

As compared to main

FP32

| model                   | inputs             |   eager |   reduce-overhead |   max-autotune |   hidet(2) |
|-------------------------|--------------------|---------|-------------------|----------------|------------|
| model/bert-base-uncased | f32, bs=1, seq=128 |   3.043 |             2.858 |          2.779 |      2.076 |

FP16

| model                   | inputs             |   eager |   reduce-overhead |   max-autotune |   hidet(2) |
|-------------------------|--------------------|---------|-------------------|----------------|------------|
| model/bert-base-uncased | f16, bs=1, seq=128 |   1.704 |             1.671 |          1.098 |      1.573 |

Member

@yaoyaoding yaoyaoding left a comment


Thanks @xinli-git ! Great job!

I left some comments, and we can discuss them offline. This PR itself is ready to be merged after passing the CI.

def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
    dims = op.attrs['dims']
    x: Tensor = op.inputs[0]
    if not is_contiguous_norm(dims, len(x.shape)):
Member


Let's discuss offline whether it is possible to implement a version that supports non-contiguous dimensions.

count_a[0] = count

@hidet.script
def norm_kernel(x: f16[x.const_shape], y: f16[y.const_shape]):
Member


We might want to read a float16x2 at a time in the future instead of a single float16; the former would have better memory efficiency.

Contributor Author


You are right; this approach currently only loads 64 bytes coalesced. I will add it.
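The gain from pairing loads can be illustrated with NumPy on a little-endian host: viewing a contiguous float16 buffer as 32-bit words halves the number of memory accesses, which is what a packed float16x2 load achieves on the GPU. This is a sketch of the idea only, not the actual kernel change:

```python
import numpy as np

x = np.arange(8, dtype=np.float16)   # contiguous f16 data
packed = x.view(np.uint32)           # each 32-bit word holds two adjacent f16 values
assert packed.size == x.size // 2    # half as many (twice as wide) accesses

# Reinterpreting a word recovers the original pair of halves unchanged
pair = packed[:1].view(np.float16)
assert pair[0] == x[0] and pair[1] == x[1]
```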

from .norm import NormalizeTask


class NormalizeF16Task(NormalizeTask):
Member


I am not sure which one is better: implementing an fp16 version separately vs. implementing the fp16 and fp32 versions in the same template.

Contributor Author


With 2x f16 it is probably better to keep them separate; I will add it.

Otherwise most of the logic is the same and we can basically template on <input dtype, accumulate dtype>.
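The single-template idea could look roughly like the following (a hypothetical sketch; `make_normalize_task` and its parameters are illustrative, not hidet's actual API):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class NormSpec:
    in_dtype: str                # element type read/written, e.g. 'float16'
    acc_dtype: str = 'float32'   # accumulate in fp32 for numerical stability

def make_normalize_task(spec: NormSpec) -> str:
    # In a real implementation this would build one hidet script kernel
    # specialized on (in_dtype, acc_dtype); here we just name the variant
    # to show that a single template covers both precisions.
    return f"normalize_{spec.in_dtype}_acc_{spec.acc_dtype}"

f16_task = make_normalize_task(NormSpec('float16'))
f32_task = make_normalize_task(NormSpec('float32'))
```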

@xinli-git
Contributor Author

Hi Yaoyao, I will merge this for now and leave the performance enhancements to a separate PR.

@xinli-git xinli-git merged commit 13217c7 into hidet-org:main May 29, 2023
@xinli-git xinli-git deleted the normalize branch May 29, 2023 15:45