Skip to content

fix: error if a complex JAX arrays type is passed to Awkward's C++ kernels#3546

Merged
pfackeldey merged 4 commits intomainfrom
pfackeldey/jax_error_differentiating_through_cpp_kernels
Jun 14, 2025
Merged

fix: error if a complex JAX arrays type is passed to Awkward's C++ kernels#3546
pfackeldey merged 4 commits intomainfrom
pfackeldey/jax_error_differentiating_through_cpp_kernels

Conversation

@pfackeldey
Copy link
Copy Markdown
Collaborator

Closes #3541.

Let's error instead of silently skipping the Awkward C++ kernel call (and thus returning wrong zeros)

Example error with this PR:

import awkward as ak
import numpy as np
import jax

ak.jax.register_and_check()

test_regulararray = ak.Array(
  [[1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]],
  backend="jax"
)

test_regulararray_tangent = ak.Array(
  [[1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]],
  backend="jax"
)

value_jvp, jvp_grad = jax.jvp(
  lambda x: ak.sort(x, axis=0), (test_regulararray,), (test_regulararray_tangent,)
)
# >> ValueError: Encountered Traced<float64[9]>with<JVPTrace> with
#   primal = Array([1., 4., 7., 2., 5., 8., 3., 6., 9.], dtype=float64)
#   tangent = Array([1., 4., 7., 2., 5., 8., 3., 6., 9.], dtype=float64) as an (invalid) input to the 'awkward_sort' Awkward C++ kernel. This kernel is not differentiable by the JAX backend.
# 
# This error occurred while calling
#
#    ak.sort(
#        <Array [[...], [...], [...]] type='3 * var * float64'>
#        axis = 0
#    )

Making Awkward's C++ kernels differentiable with JAX is quite complicated. If we ever want to do that we'd have to define per kernel the forward and backward differentiation implementation, and properly pass the dual number (primal & tangent) that's passed by JAX to the kernel through those. Alternatively, we could try expressing those kernels with JAX own primitives that define those implementations already - not sure if that is possible for all kernels though.

Copy link
Copy Markdown
Member

@ianna ianna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pfackeldey - thanks! Looks great! Please merge it if you finished with it. Thanks!

@pfackeldey pfackeldey merged commit cce47a2 into main Jun 14, 2025
42 checks passed
@pfackeldey pfackeldey deleted the pfackeldey/jax_error_differentiating_through_cpp_kernels branch June 14, 2025 19:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ak.sort for JAX Tracer [was: JVP and VJP different between ak.sort and jax.numpy.sort]

2 participants