Fix preserve_rng_state for activation checkpointing #4690
JackCaoG merged 6 commits into pytorch:master from
Conversation
Thanks! Mostly LGTM. Can you add a test case, maybe to https://github.com/pytorch/xla/blob/master/test/test_operations.py? You can compare the results on the xla device and the cpu device; this way we won't regress this.
    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
    torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
      outputs = ctx.run_function(*detached_inputs)
    with xm.fork_rng():
Any reason not to pass the rng_devices and ctx.preserve_rng_state?
It looks like the upstream code doesn't reset the state. @YangFei1990 Do you know why?
I guess the upstream seed is handled by torch.random.fork_rng? Though I am not sure why it doesn't work with pytorch/xla...
Yeah, the upstream seed is handled by torch.random.fork_rng. It will fork the torch seed, but somehow it won't set XLA's RNG. This seed, torch_xla._XLAC._xla_get_rng_seed(str(device)), is it independent of the torch seed? How does torch XLA handle RNGs in general?
I did not change the previous behavior, i.e. the upstream seed is still maintained as it was (check the code below). I simply added another RNG state to preserve.
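For illustration only, here is a minimal sketch (not the PR's actual xm.fork_rng implementation) of how an XLA-aware RNG fork could save and restore the device seed around the recomputation, assuming xm.get_rng_state()/xm.set_rng_state() wrap the _xla_get_rng_seed/_xla_set_rng_seed bindings referenced above:

    import contextlib

    import torch
    import torch_xla.core.xla_model as xm


    @contextlib.contextmanager
    def fork_rng(device=None, enabled=True):
        # Sketch only: preserve both the CPU and the XLA RNG state so that the
        # checkpointed recomputation sees the same random numbers as the
        # original forward pass.
        if not enabled:
            yield
            return
        device = device if device is not None else xm.xla_device()
        cpu_state = torch.get_rng_state()            # CPU generator state
        xla_state = xm.get_rng_state(device=device)  # XLA device seed (assumed API)
        try:
            yield
        finally:
            # Restore both states after the forked region.
            torch.set_rng_state(cpu_state)
            xm.set_rng_state(xla_state, device=device)
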
    output = torch.sum(output)
    output.backward()
    xm.mark_step()
    same_output = torch.allclose(model.to_save[0], model.to_save[1])
to_save is the container that holds the output tensor. With activation checkpointing the forward will run twice, so this container captures both tensors. Check line 2352.
    same_output = torch.allclose(model.to_save[0], model.to_save[1])
    if not same_output:
      print(f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
    self.assertTrue(same_output)
I think you can do something similar to
self.assertTrue(same_output, f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")
Awesome, didn't know I could do that. Updating.
alanwaketan left a comment
Mostly, LGTM. Please address the comments.
I will take care of the backport
In the activation checkpointing implementation we have the preserve_rng_state option; if it is set to True, activation checkpointing should use the same RNG state for the two forward runs in a single step. Consider a test script with activation checkpointing and a dropout op in the model (a sketch of such a script appears below). If everything works right, same_output should be True. However, we observed that without XLA it works correctly, but with XLA it is wrong.
This PR fixes the issue by also saving/loading XLA's RNG state in the activation checkpointing implementation. After the fix, the output matches between the two forward runs.
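For reference, here is a minimal, hedged sketch of the kind of repro script the description refers to (the exact script is not reproduced here; the module and variable names are illustrative):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint
    import torch_xla.core.xla_model as xm


    class Net(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(16, 16)
            self.dropout = nn.Dropout(0.5)
            self.to_save = []  # records the dropout output of each forward run

        def _block(self, x):
            out = self.dropout(self.linear(x))
            # With checkpointing this runs twice: once in the forward pass and
            # once during recomputation in backward, so both outputs get recorded.
            self.to_save.append(out)
            return out

        def forward(self, x):
            return checkpoint(self._block, x, preserve_rng_state=True)


    device = xm.xla_device()
    model = Net().to(device)
    # requires_grad on the input so the (reentrant) checkpoint participates in autograd.
    x = torch.randn(4, 16, device=device, requires_grad=True)

    output = torch.sum(model(x))
    output.backward()
    xm.mark_step()

    # If the RNG state is preserved across the two forward runs, the dropout
    # masks match; before this fix same_output came out False on the XLA device.
    same_output = torch.allclose(model.to_save[0], model.to_save[1])
    print('same_output =', same_output)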