Enable mixed type LayerNorm kernel for NSA indexer #12044
Fridge003 merged 4 commits into sgl-project:main
Conversation
@akhilg-nv Please add the nsys profile result showing how much time the three kernels were using before applying the change.
```python
if not self.elementwise_affine:
    return self.forward_native(x)

if _flashinfer_layernorm_available and x.dtype == torch.bfloat16 and self.dtype == torch.float32:
    ...
```
I did some benchmarking and found that for most of the cases I looked at, using the flashinfer kernel with weight = ones and bias = zeros is still faster than torch or torch.compile. So I combined these execution paths (by default I now initialize weight/bias to ones/zeros). Ideally we could make the affine transform optional in the flashinfer kernel, but this shouldn't affect DSv3.2 perf either way.
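A minimal sketch of that unified initialization (illustrative names and sizes, not the PR's exact code): defaulting the affine parameters to the identity transform lets both execution paths share the fused kernel.

```python
import torch

# Sketch only: weight = ones and bias = zeros form an identity affine
# transform, so the fused kernel can run even when no learned affine
# transform is configured. hidden_size is an arbitrary example value.
hidden_size = 128
weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
bias = torch.nn.Parameter(torch.zeros(hidden_size, dtype=torch.float32))
```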
@Fridge003 Could you provide insight on resolving the failing CI/CD tests? I see errors that seem unrelated to my changes.
@akhilg-nv That's unrelated to this PR |
Motivation
Currently we cast the input to LayerNorm to fp32, run the native Torch LayerNorm, then cast back. Instead, we pull in a more efficient TRT-LLM kernel (via flashinfer) that supports mixed-precision inputs.
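For reference, a minimal sketch of the current round-trip path (the helper name is illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

def layernorm_fp32_roundtrip(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Upcast (e.g. bf16) activations to fp32, run the native Torch
    # LayerNorm with fp32 weight/bias, then cast back to the input dtype.
    # The upcast, the norm, and the downcast are separate kernel launches.
    out = F.layer_norm(x.float(), (x.shape[-1],), weight, bias, eps)
    return out.to(x.dtype)
```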
Modifications
Add a flashinfer LayerNorm kernel and apply it in the NSA indexer.
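A sketch of the resulting dispatch, based on the diff excerpt above. `flashinfer_layernorm` and its import path are assumptions standing in for whatever binding the PR actually wires up; the fallback mirrors the round-trip shown under Motivation.

```python
import torch
import torch.nn.functional as F

try:
    # Hypothetical import path for the TRT-LLM kernel exposed via
    # flashinfer; the PR's actual binding may be named differently.
    from flashinfer.norm import layernorm as flashinfer_layernorm
    _flashinfer_layernorm_available = True
except ImportError:
    _flashinfer_layernorm_available = False

def indexer_layernorm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    if (
        _flashinfer_layernorm_available
        and x.dtype == torch.bfloat16
        and weight.dtype == torch.float32
    ):
        # Mixed-precision fused kernel: bf16 input, fp32 weight/bias,
        # one launch instead of the cast / norm / cast sequence.
        return flashinfer_layernorm(x, weight, bias, eps)
    # Fallback: the fp32 round-trip path.
    out = F.layer_norm(x.float(), (x.shape[-1],), weight, bias, eps)
    return out.to(x.dtype)
```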
Accuracy Tests
Sample results:
An nsys trace shows the new kernel being used:
Benchmarking and Profiling
Checklist