Closed
Labels
autodiff (Issue related to auto-differentiation), feature (New feature or request)
Description
Version of Awkward Array
2.8.3
Description and code to reproduce
To reproduce:
import awkward as ak
import numpy as np
import jax
ak.jax.register_and_check()
test_regulararray = ak.Array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    backend="jax"
)
test_regulararray_tangent = ak.Array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    backend="jax"
)
test_regulararray_jax = jax.numpy.array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    dtype=np.float64
)
test_regulararray_tangent_jax = jax.numpy.array(
    [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]],
    dtype=np.float64
)
axis = 0
func_ak = ak.sort
func_jax = jax.numpy.sort
def func_ak_with_axis(x):
    return func_ak(x, axis=axis)
def func_jax_with_axis(x):
    return func_jax(x, axis=axis)
value_jvp, jvp_grad = jax.jvp(
func_ak_with_axis, (test_regulararray,), (test_regulararray_tangent,)
)
value_jvp_jax, jvp_grad_jax = jax.jvp(
func_jax_with_axis, (test_regulararray_jax,), (test_regulararray_tangent_jax,)
)
value_vjp, vjp_func = jax.vjp(func_ak_with_axis, test_regulararray)
value_vjp_jax, vjp_func_jax = jax.vjp(func_jax_with_axis, test_regulararray_jax)
if __name__ == "__main__":
    print("Awkward JVP value:\n", value_jvp)
    print("JAX JVP value:\n", value_jvp_jax)
    print("Awkward JVP grad:\n", jvp_grad)
    print("JAX JVP grad:\n", jvp_grad_jax)
    print("Awkward VJP value:\n", value_vjp)
    print("JAX VJP value:\n", value_vjp_jax)
Prints out:
Awkward JVP value:
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX JVP value:
[[1. 2. 3.]
[4. 5. 6.]
[7. 8. 9.]]
Awkward JVP grad:
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX JVP grad:
[[1. 2. 3.]
[4. 5. 6.]
[7. 8. 9.]]
Awkward VJP value:
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
JAX VJP value:
[[1. 2. 3.]
[4. 5. 6.]
[7. 8. 9.]]
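So Awkward returns all zeros for the JVP value, the JVP tangent, and the VJP value, while JAX gives the sorted array back (identical to the input here, since every column is already sorted along axis 0).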
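Because the input above is already sorted along axis 0, it cannot distinguish a correct result from an identity pass-through. As a rough reference for what a fixed `ak.sort` should presumably reproduce, here is a minimal pure-JAX sketch on an unsorted input. The helper `sort_jvp_reference` and the sample arrays are hypothetical and not part of Awkward or of the original report. It assumes the tangent of a sort is the input tangent reordered by the same permutation that sorts the primal.

```python
import jax
import jax.numpy as jnp


def sort_jvp_reference(x, t, axis=0):
    # Hypothetical helper: permutation that sorts x along `axis`.
    order = jnp.argsort(x, axis=axis)
    # Primal output is x in sorted order; the tangent follows the same permutation.
    sorted_x = jnp.take_along_axis(x, order, axis=axis)
    sorted_t = jnp.take_along_axis(t, order, axis=axis)
    return sorted_x, sorted_t


# Unsorted columns, so the sort actually moves elements.
x = jnp.array([[7.0, 2.0, 9.0],
               [1.0, 8.0, 3.0],
               [4.0, 5.0, 6.0]])
t = jnp.array([[10.0, 20.0, 30.0],
               [40.0, 50.0, 60.0],
               [70.0, 80.0, 90.0]])

value, tangent = jax.jvp(lambda a: jnp.sort(a, axis=0), (x,), (t,))
ref_value, ref_tangent = sort_jvp_reference(x, t, axis=0)

print("JAX JVP value:\n", value)
print("JAX JVP grad:\n", tangent)
print("Matches reference:",
      bool(jnp.allclose(value, ref_value)),
      bool(jnp.allclose(tangent, ref_tangent)))
```

Running the same comparison with `ak.sort` on a `backend="jax"` array would show whether the Awkward result is merely zeroed out or also permuted incorrectly.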