
CUDA: Fix loop unrolling for BW in mul_mat_q_stream_k_fixup#19053

Merged

JohannesGaessler merged 1 commit into ggml-org:master from ORippler:osimons/fix_bw_mmq_fixup_kernel on Feb 3, 2026

Conversation

@ORippler
Collaborator

By providing the stride_* variables as size_t (i.e., 64-bit), the compiler can correctly unroll the [two for-loops](https://github.com/ggml-org/llama.cpp/blob/557515be1e93ed8939dd8a7c7d08765fdbe8be31/ggml/src/ggml-cuda/mmq.cuh#L3789-L3816) on BW. This gives some perf for the prefill/pp phase on BW while not affecting other SMs.

For pointer arithmetic inside loops, the general performance guidance moving forward is likely to be to perform it in 64 bit unless 32 bit is strictly necessary.

Perf numbers

| GPU | Model | Test | t/s master | t/s osimons/fix_bw_mmq_fixup_kernel | Speedup |
|:---|:---|:---|---:|---:|---:|
| NVIDIA RTX 6000 Ada Generation | gpt-oss 20B MXFP4 MoE | pp8096 | 8404.05 | 8375.79 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | gpt-oss 20B MXFP4 MoE | tg128 | 253.79 | 253.90 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | llama 3B Q4_K_M | pp8096 | 16148.93 | 16019.60 | 0.99 |
| NVIDIA RTX 6000 Ada Generation | llama 3B Q4_K_M | tg128 | 315.50 | 315.08 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | llama 8B Q4_0 | pp8096 | 8008.29 | 7978.80 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | llama 8B Q4_0 | tg128 | 168.87 | 168.85 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B BF16 | pp8096 | 4263.16 | 4248.53 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B BF16 | tg128 | 48.61 | 48.59 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B Q4_K_M | pp8096 | 5165.11 | 5157.43 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B Q4_K_M | tg128 | 111.54 | 111.47 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | gpt-oss 20B MXFP4 MoE | pp8096 | 12582.80 | 12758.37 | 1.01 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | gpt-oss 20B MXFP4 MoE | tg128 | 352.58 | 353.16 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 3B Q4_K_M | pp8096 | 16879.10 | 17619.47 | 1.04 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 3B Q4_K_M | tg128 | 426.27 | 425.65 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 8B Q4_0 | pp8096 | 10649.90 | 10982.65 | 1.03 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 8B Q4_0 | tg128 | 260.32 | 260.25 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B BF16 | pp8096 | 7717.73 | 7716.22 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B BF16 | tg128 | 83.51 | 83.51 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B Q4_K_M | pp8096 | 7301.90 | 7370.38 | 1.01 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B Q4_K_M | tg128 | 172.99 | 172.78 | 1.00 |

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jan 23, 2026
@JohannesGaessler
Contributor

Sorry, I don't understand this PR. Why can the linked loops be unrolled if and only if the strides are unsigned 64 bit integers? And how is the preference for signed loop variables relevant? In any case, there is a template ggml_cuda_unroll in common.cuh that can be used to manually unroll loops if the compiler can't do it due to e.g. continue. I think in cases such as this it's preferable to use that instead.

@ORippler
Collaborator Author

ORippler commented Jan 26, 2026

Sorry, I don't understand this PR. Why can the linked loops be unrolled if and only if the strides are unsigned 64 bit integers? And how is the preference for signed loop variables relevant?

Let me rephrase: the compiler does unroll the loops on all CCs, but due to changes introduced for >= BW in how signed int overflow is handled for 32-bit pointer arithmetic, it no longer constant-folds and groups the data-accessing LDG/STG instructions in the aforementioned loops; instead it processes them sequentially:

[screenshot: SASS on BW with 32-bit pointer arithmetic, LDG/STG instructions processed sequentially per iteration]

As a consequence, we pay the latency cost of data access more often than we have to. By explicitly performing the pointer arithmetic in 64 bit, the compiler will group all data-accessing LDG/STG instructions, reducing the number of times we have to wait for data:

[screenshot: SASS with 64-bit pointer arithmetic, LDG/STG instructions grouped]

In any case, there is a template ggml_cuda_unroll in common.cuh that can be used to manually unroll loops if the compiler can't do it due to e.g. continue. I think in cases such as this it's preferable to use that instead.

Consequently, applying ggml_cuda_unroll while doing 32-bit pointer arithmetic does not result in complete constant folding/grouping:

[screenshot: SASS with ggml_cuda_unroll and 32-bit pointer arithmetic, LDG/STG still ungrouped]

whereas combining it with explicitly declared 64-bit pointer arithmetic does:

[screenshot: SASS with ggml_cuda_unroll and 64-bit pointer arithmetic, LDG/STG grouped]
1. My current understanding is that this is a permanent change that will also apply to future HW generations. Hence, moving forward, performance guidance will likely be to do pointer arithmetic in 64 bit by default. This is guaranteed to give the best performance on all CCs and is in line with the prevalence of 64-bit applications. While 32-bit pointer arithmetic will most often be optimized to the same degree (we use it extensively throughout ggml without issues for BW GPUs), it may fail in edge cases such as the one observed here and require compiler massaging to work (I did confirm that this specific instance is not a compiler bug).
2. Since ggml_cuda_unroll + 64-bit pointer arithmetic does not give better perf over doing 64-bit pointer arithmetic alone, I did not push the related changes shown in the screenshots.

Hope that clarifies the PR.

@JohannesGaessler
Contributor

Sorry, I am not at all familiar with SASS. Do I understand you correctly that this is not a difference in terms of which PTX instructions the CUDA compiler emits, but rather that the same PTX instructions get optimized differently? I am asking because instructions like st support up to 16 bytes per thread and my impression was that that would be the bandwidth-optimal way to copy data.

Generally speaking, I'm not sure my use of signed 32 bit integers for most things is optimal. It's how I started the code when I had no CUDA experience at all and only later on did the code become so optimized that the specific data types would make a meaningful difference. For the vast majority of kernels the used integers are non-negative and fit within the 32 bit range. Does the compiler optimize the code correctly if uint32_t is used instead of size_t? Because then we can simply start using uint32_t as the default integer data type without having to worry about register pressure or (presumably) slower 64 bit integer arithmetic.

One more question: is the recommendation still to use signed integers for loops that cannot be unrolled at compile time?

@ORippler
Collaborator Author

ORippler commented Jan 29, 2026

Sorry, I am not at all familiar with SASS. Do I understand you correctly that this is not a difference in terms of which PTX instructions the CUDA compiler emits, but rather that the same PTX instructions get optimized differently? I am asking because instructions like st support up to 16 bytes per thread and my impression was that that would be the bandwidth-optimal way to copy data.

Explicitly doing pointer arithmetic in 64 bit generates different PTX than doing it in 32 bit and letting the compiler resolve the type mismatch when addressing the actual 64-bit pointer (which is always loaded as a 64-bit int). This is fine, as you are already asking for two different things in CUDA, not in PTX. Consequently, two explicit cvt instructions are present in the 64-bit function's PTX that upcast tid.x and tid.y to 64 bit, and a different mix of PTX instructions is used for the subsequent pointer arithmetic in the godbolt example here.

The issue is also not in the width of the individual data-accessing instructions, but in the potential inter-dependency of array accesses between the individual loop iterations. While array is marked as __restrict__, the 32-bit pointer arithmetic used to subscript it could overflow. Consequently, the pointer at iteration n+1 could point to the same value as iteration n, and the loop can no longer be unrolled, as the data for n+1 would then depend on n. While the issue I highlighted also exists for pre-BW GPUs, the undefinedness of 32-bit signed int overflow was exploited more aggressively on previous HW generations: a >8 GB array (2^31 = 2 giga-elements * 4 bytes/element) is feasible on current HW, and the above optimization may lead to issues/faulty behavior should the pointer arithmetic overflow. While 64-bit pointer arithmetic may overflow also, having 2^63 = 9 exa-elements is infeasible on current HW, and thus overflow is currently a non-issue.

Side note: for bandwidth-optimal data copies, one should make sure to read/write whole sectors (or better yet, cache lines) in a coalesced manner across all threads in a CTA. Basically, bandwidth-optimality requires CTA-level optimization, similar to how compute-optimality for GEMM requires tensor cores, which are also shared within a CTA. Since CTA-level-optimal code is at times convoluted to write in a SIMT language such as CUDA, there exist tile-level compilers such as cuTile and Triton that work at the CTA level. While wider loads are a good way to read/write sectors and cache lines in a coalesced way, they don't strictly have to be used unless the MMU queue gets full (LG/MIO throttle would be the keywords to look out for in Nsight Compute).

Generally speaking, I'm not sure my use of signed 32 bit integers for most things is optimal. It's how I started the code when I had no CUDA experience at all and only later on did the code become so optimized that the specific data types would make a meaningful difference. For the vast majority of kernels the used integers are non-negative and fit within the 32 bit range. Does the compiler optimize the code correctly if uint32_t is used instead of size_t? Because then we can simply start using uint32_t as the default integer data type without having to worry about register pressure or (presumably) slower 64 bit integer arithmetic.

1. uint32_t is worse than int, since unsigned int overflow is defined in C and nvcc complies with it; loop unrolling will not happen, even on older HW (double_loop_32u vs double_loop_32s in the godbolt example). Hence, the recommendation is still to use signed int for loop counters when working in 32-bit precision.
2. NVGPUs are optimized for 64-bit integer arithmetic, and slowdowns are not expected based on instruction throughput.
3. Register pressure may be an issue, but should be tackled once it's observed.

In summary: from a performance perspective, pointer arithmetic should default to 64 bit (ideally using size_t, since that is the way to represent object/pointer sizes in C++), as the compiler will always fully optimize this. The main use case for 32-bit pointer arithmetic in a 64-bit application is register pressure. Here, one should use signed int where possible for some degree of compiler optimization (most often reaching the same optimization level as 64-bit pointer arithmetic), and be ready for compiler massaging on BW in cases where the compiler fails to yield the same level of optimization (i.e., do manual loop hoisting etc.). As the kernel targeted in this PR does not suffer from register pressure, one can simply switch to 64-bit pointer arithmetic on all platforms for full compiler optimization.

@ORippler
Collaborator Author

ORippler commented Feb 2, 2026

@JohannesGaessler lmk should there be a need for additional clarifications

@JohannesGaessler
Contributor

Thank you, the situation is clear to me now. I checked the performance impact for the GPUs I have:

Performance

| GPU | Model | Microbatch size | Test | t/s b7818 | t/s 390146e | Speedup |
|:---|:---|---:|:---|---:|---:|---:|
| MI60 / MI50 | llama 8B Q4_0 | 1 | pp512 | 103.66 | 104.44 | 1.01 |
| MI60 / MI50 | llama 8B Q4_0 | 2 | pp512 | 180.19 | 181.64 | 1.01 |
| MI60 / MI50 | llama 8B Q4_0 | 4 | pp512 | 171.34 | 171.63 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 8 | pp512 | 236.45 | 236.20 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 16 | pp512 | 367.90 | 367.42 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 32 | pp512 | 476.84 | 476.90 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 64 | pp512 | 548.83 | 548.87 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 128 | pp512 | 764.81 | 764.55 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 256 | pp512 | 905.83 | 905.99 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 512 | pp512 | 1071.82 | 1074.52 | 1.00 |
| MI100 | llama 8B Q4_0 | 1 | pp512 | 133.11 | 132.05 | 0.99 |
| MI100 | llama 8B Q4_0 | 2 | pp512 | 210.47 | 210.01 | 1.00 |
| MI100 | llama 8B Q4_0 | 4 | pp512 | 229.75 | 229.12 | 1.00 |
| MI100 | llama 8B Q4_0 | 8 | pp512 | 328.93 | 328.34 | 1.00 |
| MI100 | llama 8B Q4_0 | 16 | pp512 | 757.79 | 758.39 | 1.00 |
| MI100 | llama 8B Q4_0 | 32 | pp512 | 1210.77 | 1202.70 | 0.99 |
| MI100 | llama 8B Q4_0 | 64 | pp512 | 1738.06 | 1727.17 | 0.99 |
| MI100 | llama 8B Q4_0 | 128 | pp512 | 2000.20 | 1984.62 | 0.99 |
| MI100 | llama 8B Q4_0 | 256 | pp512 | 2386.45 | 2376.55 | 1.00 |
| MI100 | llama 8B Q4_0 | 512 | pp512 | 2409.18 | 2397.74 | 1.00 |
| P40 | llama 8B Q4_0 | 1 | pp512 | 59.13 | 59.13 | 1.00 |
| P40 | llama 8B Q4_0 | 2 | pp512 | 115.58 | 115.52 | 1.00 |
| P40 | llama 8B Q4_0 | 4 | pp512 | 163.00 | 163.03 | 1.00 |
| P40 | llama 8B Q4_0 | 8 | pp512 | 218.17 | 216.92 | 0.99 |
| P40 | llama 8B Q4_0 | 16 | pp512 | 484.58 | 484.00 | 1.00 |
| P40 | llama 8B Q4_0 | 32 | pp512 | 676.80 | 676.97 | 1.00 |
| P40 | llama 8B Q4_0 | 64 | pp512 | 792.07 | 792.50 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp512 | 910.35 | 911.28 | 1.00 |
| P40 | llama 8B Q4_0 | 256 | pp512 | 998.89 | 994.37 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | pp512 | 1037.51 | 1039.30 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp512 | 163.03 | 163.05 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 2 | pp512 | 274.14 | 274.45 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 4 | pp512 | 425.08 | 423.56 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp512 | 545.44 | 544.63 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp512 | 1072.51 | 1072.82 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp512 | 1768.16 | 1770.29 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp512 | 2754.56 | 2788.27 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp512 | 3595.69 | 3588.35 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp512 | 4497.44 | 4491.01 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512 | 4856.09 | 4852.63 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp512 | 199.28 | 199.31 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp512 | 346.35 | 345.74 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp512 | 664.32 | 663.70 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp512 | 1097.16 | 1094.22 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp512 | 1889.16 | 1887.01 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512 | 3368.46 | 3363.13 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512 | 5769.53 | 5935.56 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512 | 8609.46 | 8583.75 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512 | 11487.90 | 11494.33 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 13093.98 | 13312.52 | 1.02 |
| RTX 5090 | llama 8B Q4_0 | 1 | pp512 | 303.49 | 303.73 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 2 | pp512 | 449.04 | 448.99 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 4 | pp512 | 799.99 | 800.69 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 8 | pp512 | 1212.24 | 1211.00 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 16 | pp512 | 2075.28 | 2145.84 | 1.03 |
| RTX 5090 | llama 8B Q4_0 | 32 | pp512 | 3676.99 | 3826.37 | 1.04 |
| RTX 5090 | llama 8B Q4_0 | 64 | pp512 | 5833.65 | 6366.52 | 1.09 |
| RTX 5090 | llama 8B Q4_0 | 128 | pp512 | 8203.00 | 8968.29 | 1.09 |
| RTX 5090 | llama 8B Q4_0 | 256 | pp512 | 11764.66 | 12579.47 | 1.07 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp512 | 15114.02 | 15812.35 | 1.05 |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512 | 49.48 | 49.44 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512 | 94.41 | 94.32 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512 | 167.09 | 167.18 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512 | 209.94 | 209.89 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512 | 607.60 | 607.04 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 32 | pp512 | 824.41 | 824.82 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 64 | pp512 | 1477.18 | 1473.82 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 128 | pp512 | 2234.53 | 2227.27 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 256 | pp512 | 2452.85 | 2443.00 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 512 | pp512 | 2525.50 | 2519.72 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 1 | pp512 | 142.22 | 142.33 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 2 | pp512 | 260.42 | 260.64 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 4 | pp512 | 359.87 | 359.94 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 8 | pp512 | 511.35 | 511.15 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 16 | pp512 | 731.66 | 733.11 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 32 | pp512 | 1118.66 | 1119.35 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 64 | pp512 | 622.17 | 622.28 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 128 | pp512 | 1185.36 | 1185.88 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 256 | pp512 | 2065.24 | 2065.93 | 1.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 512 | pp512 | 3041.05 | 3040.52 | 1.00 |

My RTX 4090 also seems to benefit slightly in some scenarios; for the RTX 5090 the benefits are larger. I am not seeing differences for any other GPUs that are larger than run-to-run variance, so it should probably be fine to use size_t universally for strides (instead of e.g. having to define a type per GPU vendor or generation).

@JohannesGaessler JohannesGaessler merged commit 1f1e57f into ggml-org:master Feb 3, 2026
72 of 73 checks passed
@ORippler ORippler deleted the osimons/fix_bw_mmq_fixup_kernel branch February 3, 2026 10:37
agent-enemy-2 pushed a commit to agent-enemy-2/llama.cpp that referenced this pull request Feb 4, 2026
liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026