Skip to content

Limit grad recursion depth by not recursing through non-grad inputs#1764

Merged
awni merged 2 commits intomainfrom
grad_recursion_depth
Jan 14, 2025
Merged

Limit grad recursion depth by not recursing through non-grad inputs#1764
awni merged 2 commits intomainfrom
grad_recursion_depth

Conversation

@awni
Copy link
Member

@awni awni commented Jan 10, 2025

Previously we enclose non-gradient inputs in a lambda which then get recursed through when we build the VJP graph.

This change to allow an internal VJP to take a argnums list so that the VJP can have access to all the inputs including the ones for which no gradient is requested.

  • Also fixes a bug in python grad where we return the wrong result when given multiple argnums of the same value.

An example case which would previously recurse to 10k+ and now just recurses a couple of times:

fun = lambda x, y: x * y
x = mx.array(2.0)
for _ in range(10000):
    x = mx.abs(x)
y = mx.array(3.0)
dfdx = mx.grad(fun)(x, y)

@awni
Copy link
Member Author

awni commented Jan 10, 2025

@davidkoski this limits the stack-overflow issue with VJP that we were discussing offline. It does require a change in how VJP is called from Python -> C++. I'm not sure how it's done in Swift (if you enclosed the non-grad inputs or not). But it may also require a change there to pass in the non-grad inputs to value_and_grad rather than enclosing them.

Copy link
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of mind-bender but it makes perfect sense afterwards. Nice!

l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(subtree);
return nb::cast<nb::object>(nb::tuple(l));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch.

@awni awni merged commit 33421c1 into main Jan 14, 2025
@awni awni deleted the grad_recursion_depth branch January 14, 2025 22:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants