See the change in https://github.com/ml-explore/mlx/pull/1764; it may require updating how `vjp` is called.