[Operator] Add new operator "normalize" that makes a group of layers (layer norm, group norm and instance norm) faster using hidet script (#257)
Conversation
yaoyaoding
left a comment
Thanks @xinli-git ! Great job!
I left some comments, and we can discuss them offline. This PR itself is ready to be merged once it passes CI.
def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
    dims = op.attrs['dims']
    x: Tensor = op.inputs[0]
    if not is_contiguous_norm(dims, len(x.shape)):
Let's discuss offline whether it is possible to implement a version that supports non-contiguous dimensions.
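For context on what this guard is checking, here is a hypothetical reimplementation of `is_contiguous_norm` (the real helper lives in the hidet codebase; this is only a sketch of the assumed semantics): the f16 resolve rule applies only when the normalized dims are exactly the trailing axes, so each reduction reads one contiguous span of memory.

```python
def is_contiguous_norm(dims, rank):
    """Sketch: True iff `dims` are exactly the trailing axes of a rank-`rank` tensor."""
    # e.g. layer norm over the last axes passes; normalizing a middle axis does not.
    return sorted(dims) == list(range(rank - len(dims), rank))

print(is_contiguous_norm([2, 3], 4))  # layer-norm-like, trailing axes -> True
print(is_contiguous_norm([1], 4))    # middle axis only -> False
```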
count_a[0] = count

@hidet.script
def norm_kernel(x: f16[x.const_shape], y: f16[y.const_shape]):
In the future we might want to read one float16x2 at a time instead of a single float16; the former would be more efficient.
You are right, this approach currently only loads 64 bytes coalesced. I will add it.
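To make the float16x2 suggestion concrete: the idea is that each thread issues one 32-bit load covering two adjacent float16 values, doubling the bytes moved per memory transaction. This is not hidet code, just a numpy model of the reinterpretation (assumes a little-endian host):

```python
import numpy as np

x = np.arange(8, dtype=np.float16)
pairs = x.view(np.uint32)  # 4 "float16x2" words, each holding two adjacent f16

# Unpack each 32-bit word back into its low/high float16 halves.
lo = (pairs & 0xFFFF).astype(np.uint16).view(np.float16)
hi = (pairs >> 16).astype(np.uint16).view(np.float16)

assert np.array_equal(lo, x[0::2])  # even elements came from the low halves
assert np.array_equal(hi, x[1::2])  # odd elements came from the high halves
```

In the actual kernel this would correspond to vectorized global-memory reads rather than per-element loads.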
from .norm import NormalizeTask


class NormalizeF16Task(NormalizeTask):
I am not sure which is better: implementing an fp16 version separately, vs. implementing the fp16 and fp32 versions in the same template.
With 2xf16 it is probably better to keep them separate. I will add it.
Otherwise most of the logic is the same and we can basically template on <input dtype, accumulate dtype>.
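The <input dtype, accumulate dtype> split mentioned above can be illustrated in plain numpy (a sketch only; the function name and signature are hypothetical, not hidet's API): values are stored in float16, but the mean/variance reduction accumulates in float32 to avoid precision loss.

```python
import numpy as np

def layer_norm(x, acc_dtype=np.float32, eps=1e-5):
    """Sketch: normalize over the last axis, accumulating in acc_dtype."""
    xa = x.astype(acc_dtype)                    # upcast before reducing
    mean = xa.mean(axis=-1, keepdims=True)
    var = xa.var(axis=-1, keepdims=True)
    # Normalize in the accumulate dtype, then cast back to the storage dtype.
    return ((xa - mean) / np.sqrt(var + eps)).astype(x.dtype)

y = layer_norm(np.random.rand(2, 256).astype(np.float16))
print(y.dtype)  # float16 storage, but the reduction ran in float32
```

Templating on these two dtypes would let one implementation cover both the fp32 task and the fp16 task.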
Hi Yaoyao, I will merge this for now and leave the performance enhancements in a separate PR.
Refactor the normalize op to use a dedicated task and hidet script implementation
3 main optimizations
[Benchmark screenshots: before vs. now]
I have a suspicion that this can somehow also work with epilogue fusion enabled, but I probably need to discuss it further with Yaoyao. The use case is that there is usually an affine transformation after the normalization.
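The affine-epilogue use case reads like this in a minimal numpy sketch (names `gamma` and `beta` follow the usual layer-norm convention; this is illustrative, not the hidet fusion machinery): the elementwise `gamma * x + beta` after the normalization is a natural epilogue to fuse into the kernel.

```python
import numpy as np

def layer_norm_affine(x, gamma, beta, eps=1e-5):
    """Sketch: normalize over the last axis, then apply the affine epilogue."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    # gamma/beta broadcast over rows; fusing this line avoids an extra kernel.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
out = layer_norm_affine(x, gamma=np.full(8, 2.0), beta=np.ones(8))
# with constant gamma=2, beta=1 each row ends up with mean ~1
```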
I tried the benchmark on bert-base-uncased (resnet is unaffected); there was not a huge improvement in performance, so I think this model is not heavily limited by layernorm. I'll look at some micro-benchmarking soon.
[Benchmark figures comparing against main]