
ak.sort for JAX Tracer [was: JVP and VJP different between ak.sort and jax.numpy.sort] #3541

@ikrommyd


Version of Awkward Array

2.8.3

Description and code to reproduce

To reproduce:

```python
import awkward as ak
import numpy as np
import jax

# Enable Awkward's JAX integration (needed for backend="jax")
ak.jax.register_and_check()

# Note: built from nested lists, these Awkward arrays have type
# 3 * var * float64 (ListOffsetArray), not a RegularArray.
test_regulararray = ak.Array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    backend="jax"
)

test_regulararray_tangent = ak.Array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    backend="jax"
)

# Plain JAX arrays with the same values, used as the reference
test_regulararray_jax = jax.numpy.array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    dtype=np.float64
)

test_regulararray_tangent_jax = jax.numpy.array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    dtype=np.float64
)

axis = 0
func_ak = ak.sort
func_jax = jax.numpy.sort

def func_ak_with_axis(x):
    return func_ak(x, axis=axis)

def func_jax_with_axis(x):
    return func_jax(x, axis=axis)

# Forward mode: primal output and JVP tangent
value_jvp, jvp_grad = jax.jvp(
    func_ak_with_axis, (test_regulararray,), (test_regulararray_tangent,)
)

value_jvp_jax, jvp_grad_jax = jax.jvp(
    func_jax_with_axis, (test_regulararray_jax,), (test_regulararray_tangent_jax,)
)

# Reverse mode: primal output and pullback (the pullbacks are not applied below)
value_vjp, vjp_func = jax.vjp(func_ak_with_axis, test_regulararray)
value_vjp_jax, vjp_func_jax = jax.vjp(func_jax_with_axis, test_regulararray_jax)

if __name__ == "__main__":
    print("Awkward JVP value:\n", value_jvp)
    print("JAX JVP value:\n", value_jvp_jax)
    print("Awkward JVP grad:\n", jvp_grad)
    print("JAX JVP grad:\n", jvp_grad_jax)
    print("Awkward VJP value:\n", value_vjp)
    print("JAX VJP value:\n", value_vjp_jax)
```

This prints:

```
Awkward JVP value:
 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX JVP value:
 [[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Awkward JVP grad:
 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX JVP grad:
 [[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Awkward VJP value:
 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX VJP value:
 [[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
```

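So Awkward returns all zeros, for the primal output as well as for the JVP tangent, while JAX returns the sorted array (equal to the input here, since it is already sorted along axis 0) and the tangent permuted in the same way.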
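For reference, a minimal sketch in plain `jax.numpy` (no Awkward) of what a differentiable sort is expected to return. Since sorting applies a permutation, the JVP gathers the tangent with the same `argsort` permutation as the primal, and the VJP scatters the cotangent back to the positions the values came from. The input values below are made up for illustration and chosen without ties (with ties the permutation is not unique); the `.at[perm, cols]` scatter is specific to `axis=0` on a 2-D array.

```python
import jax
import jax.numpy as jnp

# Illustrative input with distinct values in each column (no ties),
# so the sort permutation along axis 0 is unique.
x = jnp.array([[3.0, 1.0, 2.0],
               [1.0, 5.0, 0.0],
               [2.0, 0.0, 4.0]])
v = jnp.array([[10.0, 20.0, 30.0],
               [40.0, 50.0, 60.0],
               [70.0, 80.0, 90.0]])
axis = 0

def sort_along_axis(a):
    return jnp.sort(a, axis=axis)

# Permutation that sorts x along axis 0: out[i, j] = x[perm[i, j], j]
perm = jnp.argsort(x, axis=axis)

# JVP: the tangent is gathered with the same permutation as the primal.
primal_out, tangent_out = jax.jvp(sort_along_axis, (x,), (v,))
assert jnp.allclose(primal_out, jnp.take_along_axis(x, perm, axis=axis))
assert jnp.allclose(tangent_out, jnp.take_along_axis(v, perm, axis=axis))

# VJP: the cotangent is scattered back to the input positions
# (this indexing assumes axis=0 and a 2-D input).
_, pullback = jax.vjp(sort_along_axis, x)
(grad_x,) = pullback(v)
cols = jnp.arange(x.shape[1])[None, :]  # column index of every element
expected_grad = jnp.zeros_like(x).at[perm, cols].set(v)
assert jnp.allclose(grad_x, expected_grad)

print(tangent_out)
print(grad_x)
```

In the report's example the input is already sorted along axis 0, so this permutation is the identity, which is why JAX's primal output and JVP tangent both equal the input there.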

Labels: autodiff (Issue related to auto-differentiation), feature (New feature or request)