Problem
Our compiler/model contract system and layer tracing tooling has a significant gap: it does not trace or validate at the PTX (kernel) level. This allowed several GPU kernel bugs to go undetected:
- PMAT-PREFILL-FIX: Batched prefill KV scatter used a stale position_buf from the previous generation's graph capture, causing all tokens to scatter to position 0. Root cause: validate_gpu_first_token() captured a CUDA graph with position_buf=Some(0); the next generate_gpu_resident() call then ran batched prefill, which checked position_buf.is_some() → took the indirect scatter path → read the stale position 0.
- BatchedVectorizedRmsNormKernel u64 shared memory addressing: The batched kernel used u64 registers for shared memory addresses, while the single-vector kernel used u32. This is correct on sm_89 but not portable.
- BatchedQ6KGemvKernel dequant bugs: Three independent dequant errors in the batched variant that did not exist in the single-vector kernel.
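The PMAT-PREFILL-FIX failure mode can be modeled on the host in a few lines. This is a minimal sketch, not the real code path: DecodeState, scatter_positions and reset_for_new_generation are hypothetical names, and position_buf is reduced to a host-side Option<u32>. It shows why a leftover Some(0) sends every prefill token to position 0, plus one possible guard (clearing graph-mode state when a new generation starts), which may differ from the fix that actually landed.

```rust
/// Hypothetical host-side model of the batched prefill scatter decision.
/// `position_buf` stands in for the device buffer written during CUDA graph capture.
pub struct DecodeState {
    pub position_buf: Option<u32>,
}

impl DecodeState {
    /// Buggy dispatch: any value left in `position_buf` selects the indirect
    /// (graph-mode) path, even if a previous generation's capture wrote it.
    pub fn scatter_positions(&self, start_pos: u32, n_tokens: u32) -> Vec<u32> {
        match self.position_buf {
            // Indirect path: every token reads the same (possibly stale) position.
            Some(p) => vec![p; n_tokens as usize],
            // Direct path: token i is written to start_pos + i.
            None => (start_pos..start_pos + n_tokens).collect(),
        }
    }

    /// One possible guard (an assumption, not necessarily the shipped fix):
    /// drop graph-capture state before a new generation's prefill runs.
    pub fn reset_for_new_generation(&mut self) {
        self.position_buf = None;
    }
}
```

With position_buf = Some(0) left over from validate_gpu_first_token(), a 4-token prefill starting at 0 yields [0, 0, 0, 0]; after the reset it yields [0, 1, 2, 3].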
Root Cause
The current tooling stack:
- apr trace traces transformer layers (RMSNorm, Attention, FFN, etc.)
- apr validate checks tensor shapes and metadata
- apr qa validates end-to-end golden output
Missing: No tool validates that a PTX kernel produces correct output for known inputs. No tool compares batched vs single-vector kernel outputs. No tool traces the actual PTX instructions to detect addressing mismatches.
Proposed Solution
1. PTX Parity Contract (apr parity --ptx)
For every batched kernel variant, automatically verify that it produces bit-identical output to the single-vector variant for M=1:
- BatchedVectorizedRmsNormKernel vs VectorizedRmsNormKernel
- BatchedQ4KGemvKernel vs Q4KGemvKernel
- BatchedQ6KGemvKernel vs Q6KGemvKernel
- BatchedResidualAddKernel vs ResidualAddKernel
- BatchedRopeKernel vs RopeKernel
- BatchedSwigluKernel vs SwigluKernel
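A minimal sketch of the comparison step of the proposed parity contract, assuming both kernels' M=1 outputs have already been copied back to the host as f32 slices (launching the kernels is out of scope here). KERNEL_PAIRS, Parity and check_bit_parity are hypothetical names. Comparing f32::to_bits instead of using == or a tolerance is what makes the check bit-identical: 0.0 vs -0.0 counts as a mismatch, two NaNs with different payloads count as a mismatch, and two NaNs with the same bit pattern count as a match.

```rust
/// Batched kernel and the single-vector kernel it must match at M=1 (names from this issue).
pub const KERNEL_PAIRS: [(&str, &str); 6] = [
    ("BatchedVectorizedRmsNormKernel", "VectorizedRmsNormKernel"),
    ("BatchedQ4KGemvKernel", "Q4KGemvKernel"),
    ("BatchedQ6KGemvKernel", "Q6KGemvKernel"),
    ("BatchedResidualAddKernel", "ResidualAddKernel"),
    ("BatchedRopeKernel", "RopeKernel"),
    ("BatchedSwigluKernel", "SwigluKernel"),
];

/// Outcome of comparing a batched kernel's M=1 output with the single-vector output.
#[derive(Debug, PartialEq)]
pub enum Parity {
    BitIdentical,
    LengthMismatch { batched: usize, single: usize },
    /// First differing element together with both raw bit patterns.
    Mismatch { index: usize, batched_bits: u32, single_bits: u32 },
}

/// Hypothetical check: compares two host-side f32 buffers bit for bit.
pub fn check_bit_parity(batched: &[f32], single: &[f32]) -> Parity {
    if batched.len() != single.len() {
        return Parity::LengthMismatch { batched: batched.len(), single: single.len() };
    }
    for (index, (b, s)) in batched.iter().zip(single).enumerate() {
        if b.to_bits() != s.to_bits() {
            return Parity::Mismatch { index, batched_bits: b.to_bits(), single_bits: s.to_bits() };
        }
    }
    Parity::BitIdentical
}
```

Reporting the first differing index and both bit patterns, rather than a pass/fail flag, should help narrow down where a batched variant diverges, e.g. which dequantized block in a GEMV output goes wrong first.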
2. PTX Static Analysis
At compile time, verify PTX register types match expected patterns:
- Shared memory addresses should use u32 registers (%r) on all targets
- Validate that batched kernels use ctaid.y for row selection
- Detect when a kernel reads from position_buf (indirect mode) vs. an immediate position
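These register-type checks could start as a plain text scan of the emitted PTX. The sketch below is an assumption-heavy starting point, not a PTX parser: scan_ptx and PtxFindings are hypothetical, it assumes the PTX generator follows nvcc's register naming convention (%r for 32-bit, %rd for 64-bit), and the position_buf check only works if kernel parameter names survive into the PTX text.

```rust
/// Findings from a line-based scan of one kernel's PTX text.
#[derive(Debug, Default, PartialEq)]
pub struct PtxFindings {
    /// 1-based line numbers where ld.shared/st.shared uses a 64-bit (%rd) address register.
    pub shared_u64_addr_lines: Vec<usize>,
    /// Whether the kernel reads %ctaid.y (expected for batched row selection).
    pub uses_ctaid_y: bool,
    /// Whether any instruction mentions a parameter named like position_buf.
    pub reads_position_buf: bool,
}

pub fn scan_ptx(ptx: &str) -> PtxFindings {
    let mut f = PtxFindings::default();
    for (i, line) in ptx.lines().enumerate() {
        // Drop trailing `//` comments and surrounding whitespace.
        let code = line.split("//").next().unwrap_or("").trim();
        // Drop an optional guard predicate such as `@%p1` or `@!%p1`.
        let instr = match code.strip_prefix('@') {
            Some(rest) => rest.split_once(char::is_whitespace).map_or("", |(_, r)| r.trim_start()),
            None => code,
        };
        if instr.contains("%ctaid.y") {
            f.uses_ctaid_y = true;
        }
        if instr.contains("position_buf") {
            f.reads_position_buf = true;
        }
        if instr.starts_with("ld.shared") || instr.starts_with("st.shared") {
            // The address operand is the bracketed part, e.g. [%r5+4] or [%rd3].
            if let Some(open) = instr.find('[') {
                let addr = instr[open + 1..].split(']').next().unwrap_or("").trim();
                if addr.starts_with("%rd") {
                    f.shared_u64_addr_lines.push(i + 1);
                }
            }
        }
    }
    f
}
```

A line scan misses shared memory reached through generic addressing (cvta.shared followed by a plain ld/st), so a fuller version would likely need to track where each address register comes from.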
3. Layer-to-PTX Tracing
Extend apr trace to include PTX-level information:
- Which kernel was launched for each layer operation
- Launch config (grid, block, shared memory)
- Key parameter values (position, seq_len, batch_size)
- Whether indirect (graph-mode) or direct path was taken for scatter/RoPE
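One way to carry the per-launch information listed above is a single trace record per kernel launch. This is a sketch under assumptions: KernelTrace, AddrMode and the one-line output format are hypothetical, and how apr trace would capture these values at launch time is left open.

```rust
use std::fmt;

/// Whether a scatter/RoPE kernel read positions from a device buffer or took them as arguments.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AddrMode {
    Indirect,
    Direct,
}

/// Hypothetical per-launch record for a PTX-aware `apr trace`.
#[derive(Debug, Clone)]
pub struct KernelTrace {
    pub layer: usize,
    pub op: &'static str,
    pub kernel: &'static str,
    pub grid: (u32, u32, u32),
    pub block: (u32, u32, u32),
    pub shared_mem_bytes: u32,
    pub position: u32,
    pub seq_len: u32,
    pub batch_size: u32,
    pub mode: AddrMode,
}

impl fmt::Display for KernelTrace {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "L{} {} -> {} grid=({},{},{}) block=({},{},{}) smem={}B pos={} seq_len={} batch={} mode={:?}",
            self.layer, self.op, self.kernel,
            self.grid.0, self.grid.1, self.grid.2,
            self.block.0, self.block.1, self.block.2,
            self.shared_mem_bytes, self.position, self.seq_len, self.batch_size, self.mode
        )
    }
}
```

With a record like this, a batched prefill launch logged with mode=Indirect would likely have pointed at the PMAT-PREFILL bug without manual bisection.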
4. KV Cache State Validation
Add a --validate-kv-cache flag to apr qa that:
- Runs prefill with both serial and batched paths
- Compares KV cache contents after prefill
- Verifies that kv_cache_lengths matches the expected values
- Detects stale position_buf / seq_len_buf values
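The checks behind --validate-kv-cache could reduce to comparing two host-side KV snapshots plus inspecting the graph-mode buffers before prefill. Everything here is hypothetical: the KvSnapshot layout (one flattened K/V vector per layer), validate_kv_cache, and the rule that any populated position_buf or seq_len_buf at the start of a fresh generation counts as stale. atol is a tolerance because serial and batched prefill may not execute the same floating-point operations in the same order; pass 0.0 to require exact equality.

```rust
/// Problems found when comparing serial and batched prefill results.
#[derive(Debug, PartialEq)]
pub enum KvIssue {
    StaleBuffer { name: &'static str, value: u32 },
    WrongLength { path: &'static str, layer: usize, expected: usize, actual: usize },
    LayerCountMismatch { serial: usize, batched: usize },
    SizeMismatch { layer: usize, serial: usize, batched: usize },
    /// First element in a layer where the caches differ by more than `atol`.
    ValueMismatch { layer: usize, index: usize, serial: f32, batched: f32 },
}

/// Hypothetical host-side copy of a KV cache.
pub struct KvSnapshot {
    /// Per-layer K and V contents, flattened.
    pub kv: Vec<Vec<f32>>,
    /// Per-layer kv_cache_lengths.
    pub lengths: Vec<usize>,
}

pub fn validate_kv_cache(
    serial: &KvSnapshot,
    batched: &KvSnapshot,
    prompt_len: usize,
    atol: f32,
    pre_prefill_position_buf: Option<u32>,
    pre_prefill_seq_len_buf: Option<u32>,
) -> Vec<KvIssue> {
    let mut issues = Vec::new();
    // Assumed rule: a fresh generation must not start with graph-mode buffers populated.
    if let Some(value) = pre_prefill_position_buf {
        issues.push(KvIssue::StaleBuffer { name: "position_buf", value });
    }
    if let Some(value) = pre_prefill_seq_len_buf {
        issues.push(KvIssue::StaleBuffer { name: "seq_len_buf", value });
    }
    // Every layer should hold exactly prompt_len entries after prefill.
    for (path, snap) in [("serial", serial), ("batched", batched)] {
        for (layer, &actual) in snap.lengths.iter().enumerate() {
            if actual != prompt_len {
                issues.push(KvIssue::WrongLength { path, layer, expected: prompt_len, actual });
            }
        }
    }
    if serial.kv.len() != batched.kv.len() {
        issues.push(KvIssue::LayerCountMismatch { serial: serial.kv.len(), batched: batched.kv.len() });
    }
    for (layer, (s, b)) in serial.kv.iter().zip(&batched.kv).enumerate() {
        // Written as !(diff <= atol) so that NaN differences also count as mismatches.
        if let Some(index) = s.iter().zip(b).position(|(x, y)| !((x - y).abs() <= atol)) {
            issues.push(KvIssue::ValueMismatch { layer, index, serial: s[index], batched: b[index] });
        } else if s.len() != b.len() {
            issues.push(KvIssue::SizeMismatch { layer, serial: s.len(), batched: b.len() });
        }
    }
    issues
}
```

In the PMAT-PREFILL pattern every token lands in slot 0, so slot 0 holds the last token's values and the remaining slots keep whatever was there before; the value comparison flags index 0 even though kv_cache_lengths may still look correct.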
Impact
Without PTX-level validation, kernel bugs are only caught by end-to-end golden output tests, which don't isolate the failing component. This makes debugging extremely time-consuming (the PMAT-PREFILL bug took ~8 hours to isolate).
Labels
enhancement, tooling, quality