Skip to content

FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop#249

Merged
mmcky merged 1 commit intomainfrom
fix-jax-intro-fori-loop
Nov 28, 2025
Merged

FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop#249
mmcky merged 1 commit intomainfrom
fix-jax-intro-fori-loop

Conversation

@mmcky
Copy link
Copy Markdown
Contributor

@mmcky mmcky commented Nov 28, 2025

Problem

The compute_call_price_jax function in jax_intro.md was timing out during builds (600s cell execution timeout).

Root Cause

JAX unrolls Python for loops during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time as JAX traces through each iteration separately.

Solution

Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling:

# Before (Python for loop - gets unrolled)
for t in range(n):
    key, subkey = jax.random.split(key)
    Z = jax.random.normal(subkey, (2, M))
    s = s + μ + jnp.exp(h) * Z[0, :]
    h = ρ * h + ν * Z[1, :]

# After (JAX fori_loop - compiled efficiently)
def update(i, state):
    s, h, key = state
    key, subkey = jax.random.split(key)
    Z = jax.random.normal(subkey, (2, M))
    s = s + μ + jnp.exp(h) * Z[0, :]
    h = ρ * h + ν * Z[1, :]
    return s, h, key

s, h, key = jax.lax.fori_loop(0, n, update, (s, h, key))

Added an explanatory note for students about why we use fori_loop.

Related

This is the same fix as QuantEcon/lecture-python-programming#442

@netlify
Copy link
Copy Markdown

netlify bot commented Nov 28, 2025

Deploy Preview for incomparable-parfait-2417f8 ready!

Name Link
🔨 Latest commit fa8d48e
🔍 Latest deploy log https://app.netlify.com/projects/incomparable-parfait-2417f8/deploys/69291a54534ff700080e34f1
😎 Deploy Preview https://deploy-preview-249--incomparable-parfait-2417f8.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify project configuration.

The compute_call_price_jax function was timing out during builds due to
JAX unrolling the Python for loop during JIT compilation. With large
arrays (M=10,000,000), this causes excessive compilation time.

Solution: Replace the Python for loop with jax.lax.fori_loop, which
compiles the loop efficiently without unrolling.

Same fix as QuantEcon/lecture-python-programming#442
@mmcky mmcky force-pushed the fix-jax-intro-fori-loop branch from 569d0a6 to fa8d48e Compare November 28, 2025 03:43
@mmcky
Copy link
Copy Markdown
Contributor Author

mmcky commented Nov 28, 2025

@jstac just transferring the fix from QuantEcon/lecture-python-programming#442

@github-actions
Copy link
Copy Markdown

github-actions bot commented Nov 28, 2025

@github-actions github-actions bot temporarily deployed to pull request November 28, 2025 03:49 Inactive
@github-actions github-actions bot temporarily deployed to pull request November 28, 2025 03:53 Inactive
@mmcky mmcky merged commit 1991775 into main Nov 28, 2025
7 checks passed
@mmcky mmcky deleted the fix-jax-intro-fori-loop branch November 28, 2025 04:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant