Update numerical verification for SPMD Linear checkpointing #9113
sdasgup3 wants to merge 1 commit into pytorch:master
Conversation
I see that both results come from runs on TPU, one with gradient checkpointing and one without. The results should be exactly the same. Generally, some tolerance is expected when running on different hardware, but in this case I expected them to be exactly identical. I think we should take a closer look at this problem. We might find something is wrong with the way we checkpoint.
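For reference, a minimal sketch of the kind of comparison being described, using plain `torch.utils.checkpoint` rather than the actual SPMD test: train a tiny linear model twice on one device, once with gradient checkpointing and once without, and check whether the gradients match bit-for-bit. The model, shapes, and seed below are illustrative, not the test's.

```python
import torch
import torch.utils.checkpoint as cp

torch.manual_seed(0)
model = torch.nn.Linear(128, 128)
x = torch.randn(32, 128)

def run(use_checkpoint: bool) -> torch.Tensor:
    model.zero_grad()
    if use_checkpoint:
        # Recompute the forward pass during backward instead of storing activations.
        out = cp.checkpoint(model, x, use_reentrant=False)
    else:
        out = model(x)
    out.sum().backward()
    return model.weight.grad.clone()

grad_plain = run(use_checkpoint=False)
grad_ckpt = run(use_checkpoint=True)

# On the same hardware one might expect these to match bit-for-bit:
print(torch.equal(grad_plain, grad_ckpt))
print((grad_plain - grad_ckpt).abs().max())
```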
This has been a long-lived bug, and Bhavya initially expressed concern about the deviation between grad checkpointing and not. More recently, though, he is okay with merging this fix and opening a new issue to further investigate the numerical equivalence of checkpointing vs. not. I have set this PR to automerge once the tests pass. Thanks, guys.
As I look at this bug a bit further, I think it's actually okay to allow some tolerance and not expect exact equivalence between grad checkpointing and not. With grad checkpointing vs. not, you could imagine XLA reordering some operations, leading to small rounding differences in the final bits of the weights. The best place to measure tolerance would be on the weights and activations themselves; by the time you get to the loss, you can imagine a decent amount of movement.
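A toy illustration of that amplification argument (not code from this PR): a tiny relative perturbation per element stays tiny element-wise, but grows once a reduction such as a loss accumulates it over many elements.

```python
import torch

torch.manual_seed(0)
w = torch.rand(1_000_000, dtype=torch.float64)
w_perturbed = w * (1 + 1e-7)  # a tiny relative perturbation per element

print((w - w_perturbed).abs().max().item())        # per-element: about 1e-7
print((w.sum() - w_perturbed.sum()).abs().item())  # after reduction: about 5e-2
```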
This PR is stuck in CI/CD and the creator is no longer working on it. Closing; replaced by #9404.
This PR tracks an issue where an internal TPU CI is failing on v5p hardware. A specific test fails with assertion failures at test_train_spmd_linear_model.py#L49 and test_train_spmd_linear_model.py#L51, with maximum absolute differences of 0.0042718649 and 0.0000191778, respectively. The fix here is to update the corresponding atols.
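The diff itself is not shown in this thread, so the following is a hedged sketch only: the change presumably loosens the atol in the two failing assertions to sit just above the observed maxima. The helper, variable names, and exact atol values below are assumptions for illustration, not the test's actual code.

```python
import torch

def assert_allclose(baseline: torch.Tensor, checkpointed: torch.Tensor, atol: float):
    # Report the observed max absolute difference, mirroring the CI failure message.
    max_diff = (baseline - checkpointed).abs().max().item()
    assert torch.allclose(baseline, checkpointed, atol=atol), (
        f"max absolute difference {max_diff} exceeds atol={atol}")

# Hypothetical updated tolerances, chosen just above the reported differences:
# assert_allclose(loss_no_ckpt, loss_ckpt, atol=5e-3)    # covers 0.0042718649
# assert_allclose(grads_no_ckpt, grads_ckpt, atol=5e-5)  # covers 0.0000191778
```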