vulkan: f16 mixed-precision state for GATED_DELTA_NET #20376
ProgenyAlpha wants to merge 1 commit into ggml-org:master
Conversation
Please rebase and resolve the conflicts.

Working on this now.
Btw, I'm not super confident that this cast is safe in terms of quality. Since this is a recurrent state, even small deviations can accumulate into large errors over time.
Do you know if the arithmetic could be moved to fp16 if the state stays in fp32? Is there a good way to find out what is safe and what isn't? |
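One way to probe the second question empirically, outside the shader, is a small NumPy experiment: keep the recurrent state in f32 but round the per-step dot product through f16, then measure the drift against an all-f32 run. This is only an illustrative sketch, not the actual kernel; the step count, dimensions, decay constant, and update rule below are all made up for demonstration.

```python
import numpy as np

# Hypothetical experiment: f32 state with f16 dot-product arithmetic,
# compared against a fully-f32 reference. Not the real GDN recurrence.
rng = np.random.default_rng(0)
steps, dim, alpha = 1024, 128, np.float32(0.99)

k = (rng.standard_normal((steps, dim)) * 0.1).astype(np.float32)
v = (rng.standard_normal(steps) * 0.1).astype(np.float32)

s_ref = np.float32(0.0)  # f32 state, f32 arithmetic
s_mix = np.float32(0.0)  # f32 state, f16 dot product
for t in range(steps):
    d32 = np.dot(k[t], k[t])                                # f32 dot
    d16 = np.float32(np.dot(k[t].astype(np.float16),
                            k[t].astype(np.float16)))       # f16 dot
    s_ref = alpha * s_ref + v[t] * d32
    s_mix = alpha * s_mix + v[t] * d16

err = abs(float(s_mix) - float(s_ref))
print(err)  # accumulated deviation caused by f16 arithmetic
```

Sweeping `steps` and `alpha` (longer sequences, slower decay) would show whether the deviation stays bounded or grows, which is the question the comment above raises.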
This was more of a first pass to see if the f16 state approach was worth pursuing at all (the register pressure reduction did give measurable PP gains on my 890M). But given that CUDA and Metal both keep state in f32 for this op, I wanted to do more homework before committing to this path, anyway. Next steps from my end:
I can also close out the PR if you prefer?
Force-pushed from be06ee2 to 09c6d3b
To reduce register pressure, implement the sharded approach demonstrated in #20391 and #20361.
@0cc4m I'm not sure - I don't have a good intuition about the recurrent state yet. The change could be fine - I'm just not really sure.
I'll look into the sharded approach from #20391 for Vulkan. I had sharding on my todo list already, but held off on opening another PR due to the new policy. I'll do some benchmarks and validity testing first so I'm not wasting your time, and may end up closing this one out depending on the results.
My hunch is that spreading the values across more invocations and/or shared memory will be better. The "shape" of the algorithm is similar enough to ssm_scan that it seems like the same techniques should work. |
Add subgroup-sharded GATED_DELTA_NET kernel that distributes state columns across subgroup lanes (2 regs/lane on wave64 vs 128 regs/thread). Uses subgroupAdd() for reductions, with a shared-memory fallback. Also add an f16 arithmetic variant (f32 state, f16 dot products) for precision comparison testing against the f16 state variant. Four GDN pipeline variants are now available for benchmarking:

- f32 baseline (existing)
- f16 state / f32 arithmetic (existing, PR ggml-org#20376)
- f32 state / f16 arithmetic (new, flip variant)
- sharded f32 (new, preferred when S_V >= subgroup_size)
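The register math behind the sharded layout can be sketched in plain NumPy (the real kernel is a Vulkan shader): spreading 128 state columns across a 64-lane subgroup leaves each lane holding 128/64 = 2 values instead of all 128, and a `subgroupAdd()`-style reduction is just a sum over the lane axis. All names below are illustrative, not taken from the shader.

```python
import numpy as np

# Illustrative model of subgroup sharding, not the actual GLSL kernel.
SUBGROUP_SIZE = 64               # wave64, per the commit message
S_V = 128                        # state columns, per the commit message
PER_LANE = S_V // SUBGROUP_SIZE  # 2 registers per lane

state = np.arange(S_V, dtype=np.float32)

# Each lane owns a contiguous slice of the state columns:
lanes = state.reshape(SUBGROUP_SIZE, PER_LANE)

# Dot-product-style reduction: every lane computes a partial sum over
# its slice, then subgroupAdd() combines the partials across lanes.
partial = lanes.sum(axis=1)  # per-lane partial sums
total = float(partial.sum())  # models subgroupAdd() across the subgroup
print(total)
```

When the device lacks subgroup arithmetic, the same cross-lane sum would go through shared memory instead, which is the fallback the commit message mentions.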
Follow-up to #20334. Splits out the f16 mixed-precision state optimization into its own PR per @0cc4m's feedback.
Stores the 128-element state array in float16_t, keeps all arithmetic in float32. No precision loss (13/13 backend-ops tests passing). Lower register pressure gives a measurable PP boost.

Depends on #20334
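The "f16 storage, f32 arithmetic" pattern described above can be sketched in NumPy rather than GLSL: the state lives in half precision, but every update widens to f32, does the math, and narrows only on write-back. The function and variable names here are illustrative, and the update rule is a generic decay-plus-delta stand-in for the real GDN recurrence.

```python
import numpy as np

S = 128  # state elements, per the PR description

def gdn_update(state_f16, decay, delta):
    """Hypothetical sketch: widen to f32, compute, narrow for storage."""
    s = state_f16.astype(np.float32)  # widen: all arithmetic in f32
    s = decay * s + delta             # recurrent update stays in f32
    return s.astype(np.float16)       # narrow only when writing back

state = np.zeros(S, dtype=np.float16)
state = gdn_update(state,
                   np.float32(0.9),
                   np.full(S, 0.5, dtype=np.float32))
```

Because only the stored copy is rounded to f16, each step introduces at most one rounding per element, which is why this variant can pass the backend-ops tests while halving the per-element register footprint.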
890M benchmarks (Qwen3-Coder-Next REAM Q4_K_M):
The f16 pipeline auto-selects when the device supports shaderFloat16 and falls back to f32 otherwise.