FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop#249
Merged
FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop#249
Conversation
✅ Deploy Preview for incomparable-parfait-2417f8 ready!
To edit notification comments on pull requests, go to your Netlify project configuration. |
The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming#442
569d0a6 to
fa8d48e
Compare
Contributor
Author
|
@jstac just transferring the fix from QuantEcon/lecture-python-programming#442 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The
compute_call_price_jaxfunction injax_intro.mdwas timing out during builds (600s cell execution timeout).Root Cause
JAX unrolls Python
forloops during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time as JAX traces through each iteration separately.Solution
Replace the Python
forloop withjax.lax.fori_loop, which compiles the loop efficiently without unrolling:Added an explanatory note for students about why we use
fori_loop.Related
This is the same fix as QuantEcon/lecture-python-programming#442