
PTX contract validation: layer tracing doesn't trace to PTX level #219

@noahgift


Problem

Our compiler/model contract system and layer tracing tooling have a significant gap: neither traces nor validates at the PTX (kernel) level. This gap allowed several GPU kernel bugs to go undetected:

  1. PMAT-PREFILL-FIX: Batched prefill KV scatter used a stale position_buf left over from the previous generation's graph capture, so every token was scattered to position 0. Root cause: validate_gpu_first_token() captures a CUDA graph with position_buf=Some(0); the next generate_gpu_resident() call then ran batched prefill, which checked position_buf.is_some(), took the indirect scatter path, and read the stale position 0.

  2. BatchedVectorizedRmsNormKernel u64 shared memory addressing: The batched kernel used u64 registers for shared memory addresses, while the single-vector kernel used u32. This is correct on sm_89 but not portable.

  3. BatchedQ6KGemvKernel dequant bugs: Three independent dequantization errors in the batched variant that did not exist in the single-vector kernel.
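
The first bug can be modeled on the CPU in a few lines. This is only an illustrative sketch: `ScatterState` and `prefill_scatter_positions` are hypothetical names, and the `Option<u32>` stands in for the device-side position_buf rather than reproducing the real data structure.

```rust
/// Hypothetical CPU-side model of the scatter state (not the real struct).
#[derive(Debug, Default)]
struct ScatterState {
    /// `Some` after a CUDA graph capture; selects the indirect scatter path.
    position_buf: Option<u32>,
}

/// Returns the KV cache slot that each of `num_tokens` prefill tokens is
/// written to, for a prefill that should start at slot `start`.
fn prefill_scatter_positions(state: &ScatterState, start: u32, num_tokens: u32) -> Vec<u32> {
    match state.position_buf {
        // Indirect path: every token reads the same captured value, so a
        // buffer left over from an earlier capture sends them all to one slot.
        Some(captured) => vec![captured; num_tokens as usize],
        // Direct path: token i is written to slot start + i.
        None => (start..start + num_tokens).collect(),
    }
}
```

With `position_buf: Some(0)` a four-token prefill maps to slots `[0, 0, 0, 0]`; with `None` it maps to `[0, 1, 2, 3]`. End-to-end golden output catches the symptom, but not which branch was taken.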

Root Cause

The current tooling stack:

  • apr trace traces transformer layers (RMSNorm, Attention, FFN, etc.)
  • apr validate checks tensor shapes and metadata
  • apr qa validates end-to-end golden output

Missing: No tool validates that a PTX kernel produces correct output for known inputs. No tool compares batched vs single-vector kernel outputs. No tool traces the actual PTX instructions to detect addressing mismatches.

Proposed Solution

1. PTX Parity Contract (apr parity --ptx)

For every batched kernel variant, automatically verify that it produces bit-identical output to the single-vector variant for M=1:

  • BatchedVectorizedRmsNormKernel vs VectorizedRmsNormKernel
  • BatchedQ4KGemvKernel vs Q4KGemvKernel
  • BatchedQ6KGemvKernel vs Q6KGemvKernel
  • BatchedResidualAddKernel vs ResidualAddKernel
  • BatchedRopeKernel vs RopeKernel
  • BatchedSwigluKernel vs SwigluKernel
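
A minimal sketch of what such a parity check could look like, using CPU reference functions in place of real kernel launches. The names `rms_norm`, `batched_rms_norm`, and `first_bit_mismatch` are hypothetical and not part of apr; the comparison uses `f32::to_bits` so that values such as 0.0 and -0.0, which compare equal with `==`, are still reported as a mismatch.

```rust
/// CPU stand-in for the single-vector kernel (assumption: plain RMSNorm).
fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * scale * g).collect()
}

/// CPU stand-in for the batched kernel: `x` holds `m` rows back to back.
/// Assumes `m > 0` and that `x.len()` is a non-zero multiple of `m`.
fn batched_rms_norm(x: &[f32], gamma: &[f32], eps: f32, m: usize) -> Vec<f32> {
    let hidden = x.len() / m;
    x.chunks(hidden)
        .flat_map(|row| rms_norm(row, gamma, eps))
        .collect()
}

/// Returns the index of the first element whose bit pattern differs,
/// or `None` when the two outputs are bit-identical.
fn first_bit_mismatch(a: &[f32], b: &[f32]) -> Option<usize> {
    if a.len() != b.len() {
        // Treat a length difference as a mismatch at the shorter length.
        return Some(a.len().min(b.len()));
    }
    a.iter()
        .zip(b)
        .position(|(x, y)| x.to_bits() != y.to_bits())
}
```

In a real check the two functions would be replaced by launches of the single-vector and batched kernels with M=1, and a `Some(index)` result would fail the contract and name the first differing element.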

2. PTX Static Analysis

At compile time, verify that PTX register types and access patterns match expectations:

  • Check that shared memory addresses use u32 registers (%r) on all targets
  • Validate that batched kernels use ctaid.y for row selection
  • Detect when a kernel reads from position_buf (indirect mode) vs immediate position
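
The first check could start as a simple text lint over the generated PTX. The sketch below is hedged accordingly: `find_u64_shared_addressing` is a hypothetical helper, it is line-based rather than a real PTX parser, and it assumes the common naming convention where `%rd` registers are 64-bit and `%r` registers are 32-bit, which a code generator is free not to follow.

```rust
/// Returns (1-based line number, trimmed line) for every shared-memory
/// load/store whose address operand starts with a `%rd` register.
fn find_u64_shared_addressing(ptx: &str) -> Vec<(usize, String)> {
    ptx.lines()
        .enumerate()
        .filter_map(|(i, raw)| {
            let line = raw.trim();
            let mut tokens = line.split_whitespace();
            let mut opcode = tokens.next()?;
            // Skip a leading guard predicate such as `@%p1`.
            if opcode.starts_with('@') {
                opcode = tokens.next()?;
            }
            let is_mem = opcode.starts_with("ld.") || opcode.starts_with("st.");
            if !(is_mem && opcode.contains(".shared")) {
                return None;
            }
            // The address operand is the text between '[' and ']'.
            let open = line.find('[')?;
            let close = open + line[open..].find(']')?;
            let addr = line[open + 1..close].trim();
            if addr.starts_with("%rd") {
                Some((i + 1, line.to_string()))
            } else {
                None
            }
        })
        .collect()
}
```

A more robust version would resolve each register's declared type from the kernel's `.reg` declarations instead of relying on its name.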

3. Layer-to-PTX Tracing

Extend apr trace to include PTX-level information:

  • Which kernel was launched for each layer operation
  • Launch config (grid, block, shared memory)
  • Key parameter values (position, seq_len, batch_size)
  • Whether indirect (graph-mode) or direct path was taken for scatter/RoPE
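
One possible shape for a per-launch trace record covering the fields listed above. Every type and field name here is an assumption made for illustration; none of it reflects the existing apr trace output format.

```rust
use std::fmt;

/// Which path supplied the position for scatter/RoPE.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PositionPath {
    /// Position passed as an immediate kernel argument.
    Direct,
    /// Position read from a device buffer (graph mode).
    Indirect,
}

/// Hypothetical record emitted once per kernel launch.
#[derive(Debug, Clone)]
struct KernelTraceRecord {
    layer_op: String,
    kernel: String,
    grid: (u32, u32, u32),
    block: (u32, u32, u32),
    shared_mem_bytes: u32,
    position: u32,
    seq_len: u32,
    batch_size: u32,
    path: PositionPath,
}

impl fmt::Display for KernelTraceRecord {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{} -> {} grid={:?} block={:?} smem={}B pos={} seq_len={} batch={} path={:?}",
            self.layer_op,
            self.kernel,
            self.grid,
            self.block,
            self.shared_mem_bytes,
            self.position,
            self.seq_len,
            self.batch_size,
            self.path
        )
    }
}
```

Printing one such line per launch would have made the prefill bug visible directly: a batched prefill launch reporting `path=Indirect` with `pos=0` for every token.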

4. KV Cache State Validation

Add a --validate-kv-cache flag to apr qa that:

  • Runs prefill with both serial and batched paths
  • Compares KV cache contents after prefill
  • Verifies kv_cache_lengths matches expected values
  • Detects stale position_buf / seq_len_buf
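
The comparison step could be sketched as follows, operating on host copies of the two caches. `LayerKv`, `KvIssue`, and `validate_kv` are hypothetical names, and the tolerance parameter is an assumption: the proposal does not say whether serial and batched prefill are expected to match exactly or only approximately.

```rust
/// Host-side copy of one layer's KV cache after prefill (hypothetical layout).
#[derive(Debug, Clone)]
struct LayerKv {
    /// Number of positions the runtime believes are filled.
    len: usize,
    /// Flattened cache contents.
    data: Vec<f32>,
}

#[derive(Debug, PartialEq)]
enum KvIssue {
    LayerCount { serial: usize, batched: usize },
    Length { layer: usize, expected: usize, actual: usize },
    Size { layer: usize, serial: usize, batched: usize },
    Content { layer: usize, index: usize },
}

/// Compares the batched-path cache against the serial-path cache.
fn validate_kv(serial: &[LayerKv], batched: &[LayerKv], expected_len: usize, tol: f32) -> Vec<KvIssue> {
    let mut issues = Vec::new();
    if serial.len() != batched.len() {
        issues.push(KvIssue::LayerCount { serial: serial.len(), batched: batched.len() });
    }
    for (layer, (s, b)) in serial.iter().zip(batched).enumerate() {
        if b.len != expected_len {
            issues.push(KvIssue::Length { layer, expected: expected_len, actual: b.len });
        }
        if s.data.len() != b.data.len() {
            issues.push(KvIssue::Size { layer, serial: s.data.len(), batched: b.data.len() });
            continue;
        }
        // The negated `<=` also flags NaN, which a plain `>` would let through.
        let first_bad = s.data.iter().zip(&b.data).position(|(x, y)| !((x - y).abs() <= tol));
        if let Some(index) = first_bad {
            issues.push(KvIssue::Content { layer, index });
        }
    }
    issues
}
```

An empty result means the two paths agree. For a failure like PMAT-PREFILL-FIX, where every token lands in position 0, the check would be expected to report both a length issue and a content issue for the affected layer.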

Impact

Without PTX-level validation, kernel bugs are only caught by end-to-end golden output tests, which don't isolate the failing component. This makes debugging extremely time-consuming (the PMAT-PREFILL bug took ~8 hours to isolate).

Labels

enhancement, tooling, quality
