vulkan: Add State Space Model (SSM) Operations Support #16463
0cc4m merged 2 commits into ggml-org:master
Conversation
jeffbolznv left a comment
Thanks for this contribution!
    warp_sdata[warp_offset + lane] = val;
    barrier();

    if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
This seems like it's assuming a subgroup size of 32 (also at line 37).
Do I understand correctly that this doesn't actually rely on a subgroup size of 32, but it's splitting the workgroup into groups of 32 and just reducing those (and it looks like some reduction across groups of 32 has already happened?).
Sorry, I've missed this one. Yeah, I don't think it would work with a size != 32. I need to think this one through more.
Do you have any suggestions on what I could do here?
I think this may work because you're not relying on SubgroupInvocationId or SubgroupID, you've just split the workgroup into groups of 32. Maybe we can just test it on AMD (with wave64) and Intel and verify that it works.
It works on Intel, but I am worried about all these settings that we made configurable. I've not really tested how it behaves with different values of the constants we defined. Or is the assumption that these values should not be tweaked from vulkan-shaders-gen.cpp without also changing the implementation in the shader?
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
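For illustration, a minimal sketch of that loop (not the PR's code; it assumes WARP_SIZE is a power of two and reuses warp_sdata, warp_offset, and lane from the quoted shader):

    // Tree reduction over one group of WARP_SIZE threads in shared memory.
    // Each step halves the number of active lanes; works for any
    // power-of-two WARP_SIZE, not just 32.
    uint offset = WARP_SIZE / 2;
    while (offset > 0) {
        if (lane < offset) {
            warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset];
        }
        barrier();
        offset /= 2;
    }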
Hmm, wave64 AMD and wave8 llvmpipe are failing one test here, possibly due to this. All other tests are passing.
[SSM_SCAN] NMSE = 31335529439335960.000000000 > 0.000000100 SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): FAIL
I think Intel also has a subgroup size of 32 so it wouldn't be a good test for this.
        return warp_sdata[warp_offset];
    }

    void main() {
Do all threads always load/store in bounds? In the host code there was some rounding up going on, which suggests maybe some threads don't correspond to in-bounds locations.
Ping on this one. I don't really understand what this shader does and which locations it should be accessing.
I've tried to follow what the CUDA shader does. I'll spend more time on it and see if there is anything I can improve about memory access and make sure all the assumptions in the code are checked.
I've addressed the comments and pushed a new version. The results are even better now.
@@ -919,6 +919,12 @@ void process_shaders() {
    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

    string_to_spv("ssm_scan_f32_d16", "ssm_scan.comp", {{"A_TYPE", "float"}});
    string_to_spv("ssm_scan_f32_d128", "ssm_scan.comp", {{"A_TYPE", "float"}});
    string_to_spv("ssm_scan_f32_d256", "ssm_scan.comp", {{"A_TYPE", "float"}});
These three are all identical now, you only need one.
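For instance, a single entry along the lines of string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); (name hypothetical) would cover all three, since the variants no longer differ in their defines.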
    warp_sdata[warp_offset + lane] = val;
    barrier();

    if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
I've completely replaced the code to reduce the sum with subgroup operations. In the last version I've also renamed WARP_SIZE to SUBGROUP_SIZE.
Be aware that not all devices support subgroup commands. If there's a performance advantage to using them, you can do that, but it would still need a fallback to using a shared memory reduction. If there isn't a performance advantage, just use the shared memory reduction for compatibility.
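For comparison, a minimal sketch of the two paths (BLOCK_SIZE and USE_SUBGROUP_ADD are hypothetical names, not from the PR; the host would define USE_SUBGROUP_ADD only on devices that support subgroup arithmetic):

    #if USE_SUBGROUP_ADD
    #extension GL_KHR_shader_subgroup_arithmetic : enable
    #endif

    shared float sdata[BLOCK_SIZE];

    float reduce_sum(float val, uint tid) {
    #if USE_SUBGROUP_ADD
        // Fast path: one instruction per subgroup; a cross-subgroup combine
        // is still needed when the workgroup spans several subgroups (omitted).
        return subgroupAdd(val);
    #else
        // Portable fallback: shared-memory tree reduction over the workgroup.
        sdata[tid] = val;
        barrier();
        for (uint offset = BLOCK_SIZE / 2; offset > 0; offset >>= 1) {
            if (tid < offset) {
                sdata[tid] += sdata[tid + offset];
            }
            barrier();
        }
        return sdata[0];
    #endif
    }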
With the subgroup code I get: […] If I revert to the reduction loop, I get: […]

I've reverted to the version with a for loop. We can look at the subgroup optimization later.
Force-pushed from c467631 to 6e70718.
Force-pushed from c9d79db to b5ed953.
Could it be approved again?
I'll do a proper review soon.
    };

    void main() {
        const uint global_thread_id = gl_WorkGroupID.x * gl_WorkGroupSize.x + gl_LocalInvocationID.x;
In Vulkan you can shorten this to gl_GlobalInvocationID.x
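(For reference, the GLSL compute built-ins satisfy, per component, gl_GlobalInvocationID = gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID, so the quoted expression reduces to const uint global_thread_id = gl_GlobalInvocationID.x;)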
    const int stride_dt = int(src2_nb1) / 4;
    const int stride_B = int(src4_nb2) / 4;
    const int stride_C = int(src5_nb2) / 4;
    const int stride_y = int(n_head * d_head);
Why use int everywhere? It leads to a lot of casting, and the values don't look like they can/should be negative. Indices should be uints.
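A sketch of the suggested change (assuming the src*_nb* push constants are already unsigned byte strides):

    // Element strides derived from byte strides; unsigned, since these
    // only ever index forward into the buffers.
    const uint stride_dt = src2_nb1 / 4;
    const uint stride_B  = src4_nb2 / 4;
    const uint stride_C  = src5_nb2 / 4;
    const uint stride_y  = n_head * d_head;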
        state[j] = s0[s0_base_idx + j * D_STATE + tid];
    }

    if (tid >= D_STATE) {
Isn't D_STATE the workgroup size as well? If that's the case, this can't be true.
I've added it only to make it clear when reading the code that there can't be any OOB access, but I agree it is superfluous. I'll drop it.
    float dt_soft_plus = dt[dt_base_idx + i * stride_dt];
    dt_soft_plus = softplus(dt_soft_plus);
Suggested change:

    const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
Not sure how much of a difference this kind of stuff makes, but with the large variety of Vulkan compilers, I prefer to make it as easy as possible for them with const.
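For context, softplus(x) = log(1 + e^x); a common numerically stable form (a sketch, the PR's helper may differ) is:

    float softplus(float x) {
        // For large x, log(1 + exp(x)) ~= x, and exp(x) would overflow anyway.
        return (x > 20.0) ? x : log(1.0 + exp(x));
    }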
    int lane = tid % SUBGROUP_SIZE;

    warp_sdata[tid] = y;
Just curious, I don't understand or know the algorithm. Is the switch to the second shared buffer here to save on a barrier? It should be possible to continue the reduction with only stateC, or am I missing something?
Of course it also makes replacing the step with subgroup operations easier.
No reason; I tried different ways before I got it to work. Your version is much better, thanks for the suggestion.
Thanks for the review, pushed an updated version.
@jeffbolznv thanks again, fixed the last comments.
jeffbolznv left a comment
LGTM. All my comments have been addressed. I haven't had a chance to actually run it yet, I can try to do that tomorrow, but I don't want to block the change.
@0cc4m would you mind another look?
I ran the backend tests locally and they passed and had no validation errors. I noticed one test was unsupported and I thought this was supported in earlier versions of the change. Was it intentional to remove it?
Oh, there's a failure in the lavapipe CI: lavapipe uses an unusual subgroup size, which might be related.

Yeah, I've dropped support for Mamba 1 from this PR as I've not got it to work yet. I can add it in a follow-up PR.
I did a quick experiment and changing spec constant 1 from device->subgroup_size to 32 fixed the lavapipe failure.

How do I run this locally? In what CI task is it happening?
It's the ubuntu-24-cmake-vulkan CI job. I think you'll need either Linux or WSL to run lavapipe (I run it on Windows with WSL). You'll probably need to set the env var GGML_VK_VISIBLE_DEVICES=0 (assuming this is enumerated as the first device) to enable lavapipe.
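For example, assuming a standard CMake build layout, something like GGML_VK_VISIBLE_DEVICES=0 ./build/bin/test-backend-ops -o SSM_SCAN should then reproduce the failing case, with lavapipe enumerated as device 0.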
Never mind, I see it now. Taking a look.
To simplify the review, I post here the patch I've applied on top of the previous version:

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
index 06aa10bfe..daebb5dc0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
@@ -96,7 +96,7 @@ void main() {
barrier();
}
- [[unroll]] for (uint j = 0; j < SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
+ [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
@@ -104,13 +104,13 @@ void main() {
uint lane = tid % SUBGROUP_SIZE;
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
- if (lane < offset) {
+ if (idx < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
- if (tid % SUBGROUP_SIZE == 0) {
+ if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = stateC[idx];
    }

The test passes locally now.
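(If I read the patch right: the loop runs one extra iteration and every access is guarded by idx < SPLIT_H * D_STATE, so the reduction covers the full range without going out of bounds when the subgroup size isn't 32.)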
Passes for me, too.
Wow, that really is a game changer for Vulkan on NVIDIA cards. Here is an example from my 4060 Mobile GPU: […] Original: […] I understand that one needs special SSM/Mamba(?) LLM models in order to get this. Does anyone know how to specifically search for these kinds of models on Huggingface? Thanks so much for your work on this, giuseppe, and please get this into upstream... Regards,
0cc4m left a comment
It works as intended, I just have some comments about the supports_op code.
ggml/src/ggml-vulkan/ggml-vulkan.cpp (Outdated)
    const uint32_t MAX_D_STATE = 256;

    size_t stateC_size = SPLIT_H * MAX_D_STATE * sizeof(float);
    size_t warp_sdata_size = MAX_D_STATE * sizeof(float);
You forgot to update this calculation.
ggml/src/ggml-vulkan/ggml-vulkan.cpp (Outdated)
    const uint32_t SPLIT_H = 16;
    const uint32_t MAX_D_STATE = 256;

    size_t stateC_size = SPLIT_H * MAX_D_STATE * sizeof(float);
Shouldn't this be using d_state instead of MAX_D_STATE, since the smaller shader may fit even if the large one does not?
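For a sense of scale, using the quoted constants: with SPLIT_H = 16 and d_state = 16, stateC needs 16 * 16 * 4 B = 1 KiB, while sizing it with MAX_D_STATE = 256 demands 16 * 256 * 4 B = 16 KiB (plus 1 KiB of warp_sdata), so a device with a 16 KiB shared-memory limit would reject the small shader even though it actually fits.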
Add State Space Model scan operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Add State Space Model conv operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
@0cc4m fixed the last comments and also added the new operations to […]
0cc4m left a comment
Thank you, looks good now.
@0cc4m thanks! CI is green now.
* vulkan: implement SSM scan operation

Add State Space Model scan operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

* vulkan: implement SSM conv operation

Add State Space Model conv operation to the Vulkan backend.

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Implement SSM scan and SSM conv for Vulkan.
Intel Arc:
ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: bc07349 (6756)
ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: a4d94598e (6769)
NVIDIA:
ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: 554fd57 (6766)
ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: a4d94598e (6769)