Skip to content

[JAX] Triton binding#2437

Merged
phu0ngng merged 11 commits into
NVIDIA:mainfrom
phu0ngng:triton_binding
Dec 2, 2025
Merged

[JAX] Triton binding#2437
phu0ngng merged 11 commits into
NVIDIA:mainfrom
phu0ngng:triton_binding

Conversation

@phu0ngng

@phu0ngng phu0ngng commented Dec 1, 2025

Copy link
Copy Markdown
Collaborator

Description

This PR adds utilities to lower the JAX custom call to the Triton kernel.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@greptile-apps

greptile-apps Bot commented Dec 1, 2025

Copy link
Copy Markdown
Contributor

Greptile Overview

Greptile Summary

This PR adds a new triton_extensions module to enable JAX primitives to use Triton kernels. The implementation provides:

  • triton_call_lowering: A utility function that handles MLIR lowering for Triton kernels, including support for autotuned kernels with multiple configurations
  • Kernel compilation with caching: Compiled kernels are cached by signature, constants, and compile options to avoid redundant compilation
  • JAX-Triton bridge: Uses JAX's internal gpu_triton APIs to create kernel calls and lower them through the FFI

Key implementation details:

  • Autotuned kernels compile all configurations upfront for runtime selection via TritonAutotunedKernelCall
  • Grid dimensions are normalized to 3D tuples for the GPU backend
  • The test demonstrates proper grid calculation using the minimum BLOCK_SIZE from autotuner configs to ensure all elements are processed regardless of which config is selected at runtime

Confidence Score: 4/5

  • This PR is safe to merge - it adds a new module without modifying existing code paths, and includes test coverage for the core functionality.
  • Score of 4 reflects: clean implementation using JAX's official gpu_triton APIs, test coverage demonstrating the feature works, and proper handling of autotuned kernels. Minor concerns include the test primitive not implementing all BasePrimitive abstract methods (limiting it to single-device JIT), and limited test coverage (single shape/dtype).
  • The test primitive in tests/jax/test_triton_custom_calls.py is intentionally limited scope - consider expanding test coverage for production use.

Important Files Changed

File Analysis

Filename Score Overview
build_tools/jax.py 5/5 Added triton to test dependencies - simple, safe change.
tests/jax/test_triton_custom_calls.py 4/5 New test file demonstrating Triton kernel integration. Test primitive uses atomic_max for amax computation with autotuning. Grid calculation correctly uses minimum BLOCK_SIZE. Test coverage limited to a single shape/dtype combination.
transformer_engine/jax/triton_extensions/init.py 5/5 Module init file with clear documentation about Triton dependencies and usage patterns.
transformer_engine/jax/triton_extensions/utils.py 4/5 Core utility module providing triton_call_lowering function for JAX-Triton integration. Implements kernel compilation with caching, dtype mapping, and autotuner support. Uses JAX's internal gpu_triton APIs for lowering.

Sequence Diagram

sequenceDiagram
    participant User as JAX Primitive
    participant TCL as triton_call_lowering
    participant CT as compile_triton
    participant GPU as gpu_triton (JAX)
    participant FFI as JAX FFI

    User->>TCL: Call with kernel_fn, ctx, arrays, grid
    TCL->>TCL: Get compute capability
    TCL->>TCL: Build signature from avals
    TCL->>TCL: Normalize grid to 3D

    alt Autotuned Kernel
        loop For each config
            TCL->>CT: Compile with config params
            CT->>CT: Check cache
            CT->>CT: Build ASTSource
            CT->>CT: Compile to PTX
            CT->>GPU: Create TritonKernel
            GPU-->>CT: Return kernel
            CT-->>TCL: Return cached/compiled kernel
            TCL->>GPU: Create TritonKernelCall
        end
        TCL->>GPU: Create TritonAutotunedKernelCall
    else Regular Kernel
        TCL->>CT: Compile kernel
        CT-->>TCL: Return kernel
        TCL->>GPU: Create TritonKernelCall
    end

    TCL->>GPU: Serialize to protobuf
    TCL->>FFI: ffi_lowering with compressed proto
    FFI-->>User: Return MLIR result
Loading

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment thread tests/jax/test_triton_custom_calls.py
Comment thread transformer_engine/jax/triton_extensions/utils.py
phu0ngng and others added 2 commits December 1, 2025 12:28
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@phu0ngng

phu0ngng commented Dec 1, 2025

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L0

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Comment thread tests/jax/test_triton_custom_calls.py
Comment thread transformer_engine/jax/triton_extensions/utils.py
Comment thread transformer_engine/jax/triton_extensions/utils.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 2 commits December 2, 2025 10:37
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

phu0ngng commented Dec 2, 2025

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L0

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Comment thread tests/jax/test_triton_custom_calls.py Outdated
Comment thread tests/jax/test_triton_custom_calls.py Outdated
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

phu0ngng commented Dec 2, 2025

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L0

@greptile-apps greptile-apps Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@phu0ngng phu0ngng merged commit f1512b2 into NVIDIA:main Dec 2, 2025
21 of 23 checks passed
@phu0ngng phu0ngng deleted the triton_binding branch December 2, 2025 17:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants