gerlero/mutax

Mutax

SciPy-like differential evolution for JAX

Fully jitted optimization of any JAX-compatible function. Serial and parallel execution on CPU, GPU, and TPU.

Installation

pip install mutax

Quick start

import jax.numpy as jnp
from mutax import differential_evolution

def cost_function(xs):
    return jnp.sum(xs**2)

bounds = [(-5, 5)] * 10  # 10-dimensional problem with bounds for each dimension

result = differential_evolution(cost_function, bounds)
print("Best solution:", result.x)
print("Objective value:", result.fun)
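
For readers coming from SciPy, the same problem can be written against `scipy.optimize.differential_evolution`, which is the interface the "SciPy-like" description refers to. This is only a comparison sketch: it uses SciPy and NumPy rather than Mutax, the name `scipy_result` is chosen here for illustration, and it should not be read as a statement about which keyword arguments Mutax accepts.

```python
import numpy as np
from scipy.optimize import differential_evolution as scipy_differential_evolution

def cost_function(xs):
    # Same sphere function as in the quick start, written with NumPy
    return np.sum(xs**2)

bounds = [(-5, 5)] * 10  # one (lower, upper) pair per dimension

# SciPy's optimizer takes the objective and the bounds in the same order
scipy_result = scipy_differential_evolution(cost_function, bounds)
print("Best solution:", scipy_result.x)
print("Objective value:", scipy_result.fun)
```

In both cases the call takes the cost function followed by the bounds, and the result exposes the best point as `x` and its objective value as `fun`.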

Documentation

The documentation is available at Read the Docs.