SciPy-like differential evolution for JAX
Fully jitted optimization of any JAX-compatible function. Serial and parallel execution on CPU, GPU, and TPU.
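In practice, "JAX-compatible" means the cost function is a pure function of a 1-D array built from `jax.numpy` operations, so that JAX can trace and compile it. The sketch below is illustrative only and does not show how mutax evaluates a population internally. For illustration, it assumes that "parallel" corresponds to vectorized evaluation with `jax.vmap` and "serial" to one-candidate-at-a-time evaluation with `jax.lax.map`. The names `sphere` and `population` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch only: these are not mutax internals or mutax options.


def sphere(x):
    # Pure function of a 1-D array built only from jax.numpy ops, so it can be
    # traced and compiled. A Python `if` on a traced value (e.g. `if x[0] > 0:`)
    # would fail under jax.jit; use jnp.where or jax.lax.cond instead.
    return jnp.sum(x**2)


# Hypothetical population: 8 candidate solutions in 10 dimensions.
population = jax.random.uniform(jax.random.PRNGKey(0), (8, 10), minval=-5.0, maxval=5.0)

# Vectorized ("parallel") evaluation: one batched call over the population axis.
parallel_eval = jax.jit(jax.vmap(sphere))

# Sequential ("serial") evaluation: lax.map runs the cost one candidate at a time,
# which can lower peak memory for expensive cost functions.
serial_eval = jax.jit(lambda pop: jax.lax.map(sphere, pop))

f_parallel = parallel_eval(population)
f_serial = serial_eval(population)
print(f_parallel.shape)  # (8,)
```

Both versions compute the same values. The trade-off is between throughput (vectorized) and memory use (sequential).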
pip install mutax

import jax.numpy as jnp
from mutax import differential_evolution
def cost_function(xs):
    # Sphere function: minimum value 0 at the origin
    return jnp.sum(xs**2)

bounds = [(-5, 5)] * 10  # 10-dimensional problem: (lower, upper) bounds for each dimension
result = differential_evolution(cost_function, bounds)
print("Best solution:", result.x)
print("Objective value:", result.fun)

The documentation is available at Read the Docs.
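To show roughly what a differential evolution call involves, here is a minimal sketch of the classic DE/rand/1/bin scheme in plain JAX, compiled end to end with `jax.jit`. It is not mutax's implementation. The function `de_sketch` and its parameters (`popsize`, `mutation`, `crossover`, `iters`) are hypothetical names chosen for this example and are not mutax options. The strategy also differs from SciPy's default, `best1bin`.

```python
import jax
import jax.numpy as jnp


def cost_function(xs):
    # Same sphere objective as the example above.
    return jnp.sum(xs**2)


def de_sketch(f, lower, upper, key, popsize=40, mutation=0.7, crossover=0.9, iters=1000):
    """Hypothetical DE/rand/1/bin minimizer for illustration; not the mutax API."""
    dim = lower.shape[0]
    key, init_key = jax.random.split(key)
    # Initial population sampled uniformly inside the bounds.
    pop = lower + (upper - lower) * jax.random.uniform(init_key, (popsize, dim))
    fit = jax.vmap(f)(pop)

    def pick_donors(k, i):
        # Three distinct indices, all different from the target index i.
        r = jax.random.choice(k, popsize - 1, (3,), replace=False)
        return r + (r >= i)

    def generation(_, carry):
        pop, fit, key = carry
        key, k_idx, k_cr, k_j = jax.random.split(key, 4)
        r = jax.vmap(pick_donors)(jax.random.split(k_idx, popsize), jnp.arange(popsize))
        # Mutation: a + F * (b - c), clipped back into the bounds.
        mutant = pop[r[:, 0]] + mutation * (pop[r[:, 1]] - pop[r[:, 2]])
        mutant = jnp.clip(mutant, lower, upper)
        # Binomial crossover; j_rand guarantees at least one mutant coordinate.
        cross = jax.random.uniform(k_cr, (popsize, dim)) < crossover
        j_rand = jax.random.randint(k_j, (popsize,), 0, dim)
        cross = cross | (jnp.arange(dim)[None, :] == j_rand[:, None])
        trial = jnp.where(cross, mutant, pop)
        # Greedy selection: keep the trial if it is at least as good as the target.
        trial_fit = jax.vmap(f)(trial)
        better = trial_fit <= fit
        pop = jnp.where(better[:, None], trial, pop)
        fit = jnp.where(better, trial_fit, fit)
        return pop, fit, key

    # The whole generation loop runs inside one compiled program.
    pop, fit, _ = jax.lax.fori_loop(0, iters, generation, (pop, fit, key))
    best = jnp.argmin(fit)
    return pop[best], fit[best]


# f, popsize and iters fix the traced structure and array shapes, so they are static.
de_jit = jax.jit(de_sketch, static_argnames=("f", "popsize", "iters"))

lower = jnp.full(10, -5.0)
upper = jnp.full(10, 5.0)
x_best, f_best = de_jit(cost_function, lower, upper, jax.random.PRNGKey(0))
print("Best solution:", x_best)
print("Objective value:", f_best)
```

Putting the generation loop in `jax.lax.fori_loop` and the population evaluation in `jax.vmap` is one way to obtain a fully jitted optimizer. Because the result is an ordinary compiled JAX function, it runs unchanged on CPU, GPU, or TPU.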