🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124573
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    self,
    cond_fn: Callable,
    body_fn: Callable,
    body_grad_fn: Callable,
Ideally, the backward of while_loop should be another while_loop operator with the same signature: cond_fn, body_fn, operands.
To derive the backward cond_fn and the backward body_fn, consider the following example.
- forward: inp0 -> body_fn -> inp1 -> body_fn -> … inpN-1 -> body_fn -> output.
- corresponding backward would be: grad_output, output, inpN-1 -> backward_body_fn -> grad_inpN-1, inpN-1, inpN-2 -> backward_body_fn -> … grad_inp1, inp1, inp0 -> backward_body_fn -> grad_inp0
Specifically, one possible design could be:
def backward_cond_fn(grad, fw_outputs):  # fw_outputs holds entries 0…N
    return len(fw_outputs) > 1

def backward_body_fn(grad, fw_outputs):
    output = fw_outputs.pop()  # fw_outputs now holds entries 0…N-1
    inp_prev = fw_outputs[-1]  # inpN-1
    # Re-compute based on inpN-1, since we didn't save the intermediates of each iteration.
    # We could extend this by saving the important intermediates when necessary.
    grad_prev = fw_bw(grad, output, inp_prev)
    return grad_prev, fw_outputs  # entries 0…N-1

The backward is then:
while_loop(backward_cond_fn, backward_body_fn, (grad_out, fw_outputs))
This might require us to support a dynamic list with unspecialized length cc @zou3519
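To make the design above concrete, here is a minimal plain-Python sketch, not the PR's implementation. It assumes a scalar body x -> x * x with a hand-written VJP (`body_vjp`, a hypothetical stand-in for `fw_bw`), an eager `while_loop` stand-in instead of the real operator, and an ordinary Python list for `fw_outputs`. The helper `forward_with_history` is also hypothetical.

```python
from typing import Callable, List, Tuple


def while_loop(cond_fn: Callable, body_fn: Callable, operands: Tuple) -> Tuple:
    """Eager stand-in for the while_loop operator (assumption: same calling convention)."""
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


def body_fn(x: float) -> float:
    # Hypothetical forward body used only for illustration.
    return x * x


def body_vjp(grad: float, inp: float) -> float:
    # Hypothetical stand-in for fw_bw: re-derives the local gradient from the
    # saved input of the iteration (d(x*x)/dx = 2x) and chains it with grad.
    return grad * 2.0 * inp


def forward_with_history(x0: float, n_iters: int) -> Tuple[float, List[float]]:
    """Run the forward loop and record inp0 ... inpN-1 and the final output."""

    def cond(i: int, x: float, hist: List[float]) -> bool:
        return i < n_iters

    def step(i: int, x: float, hist: List[float]):
        y = body_fn(x)
        return i + 1, y, hist + [y]

    _, out, hist = while_loop(cond, step, (0, x0, [x0]))
    return out, hist


def backward_cond_fn(grad: float, fw_outputs: List[float]) -> bool:
    # Keep iterating while there is still an input to differentiate against.
    return len(fw_outputs) > 1


def backward_body_fn(grad: float, fw_outputs: List[float]):
    remaining = fw_outputs[:-1]  # drop this step's output without mutating the input
    inp_prev = remaining[-1]     # input of the forward iteration being reversed
    return body_vjp(grad, inp_prev), remaining


def backward(grad_out: float, fw_outputs: List[float]) -> float:
    grad_in, _ = while_loop(backward_cond_fn, backward_body_fn, (grad_out, fw_outputs))
    return grad_in


out, hist = forward_with_history(3.0, 2)  # 3 -> 9 -> 81, i.e. x ** 4
grad_inp0 = backward(1.0, hist)           # analytic gradient is 4 * x ** 3
```

The carried list shrinks by one entry per backward iteration, which is exactly why the real operator would need to handle a list of unspecialized length.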
I tried to follow your suggestion in my latest commit.
I just leave this here for future developments on this:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
@pytorchbot label no-stale
❌ 🤖 pytorchbot command failed: Try
@pytorchbot label no-stale
b7aefd8 to eaf9d31
This PR needs a
eaf9d31 to 01bb249
This PR is an attempt to add autograd support to the while_loop functionality of PyTorch. @ydwu4
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec