[TEST ONLY] print statements for test_zero1.py to debug #5377

Closed
janeyx99 wants to merge 2 commits into master from test-print-statements

@janeyx99 (Contributor) commented Jul 31, 2023

Ah, the issue is that the XLA ZeroRedundancyOptimizer makes a deepcopy and keeps the original state_dict passed in unmodified, whereas our optimizer no longer makes a copy, so ours will have progressed with step.

The logs linked below are from a run with these print statements added to the test:

    opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    opt2 = ZeroRedundancyOptimizer(
        model.parameters(),
        torch.optim.SGD,
        lr=0.01,
        momentum=0.9,
        grad_clipping=False)

    opt1.step()
    opt2.step()
    s1 = opt1.state_dict()
    s2 = opt2.state_dict()
    print("AFTER STEPPING ONCE")
    print("opt1.state", opt1.state)
    print("opt1.state_dict()", s1)
    print("opt2.state[base]", opt2.state['base'])
    print("opt2.state_dict()[base]", s2['base'])
    self.assertEqual(s1, s2['base'])

    # s1_clone = deepcopy(s1)
    # s2_clone = deepcopy(s2)
    opt1.load_state_dict(s1)
    opt2.load_state_dict(s2)
    print("AFTER LOADING THE STATE_DICTs, should be same as before")
    print("opt1.state", opt1.state)
    print("opt1.state_dict()", opt1.state_dict())
    print("opt2.state", opt2.state['base'])
    print("opt2.state_dict()[base]", opt2.state_dict()['base'])
    self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

    # step still runnable
    opt1.step()
    opt2.step()
    print("AFTER STEPPING AGAIN, WILL be different")
    print("opt1.state", opt1.state)
    print("opt1.state_dict()", opt1.state_dict())
    print("opt2.state", opt2.state['base'])
    print("opt2.state_dict()[base]", opt2.state_dict()['base'])
    opt1.load_state_dict(s1)
    opt2.load_state_dict(s2)
    print("AFTER LOADING THE STATE_DICTs, should be same as before")
    print("opt1.state", opt1.state)
    print("opt1.state_dict()", opt1.state_dict())
    print("opt2.state", opt2.state['base'])
    print("opt2.state_dict()[base]", opt2.state_dict()['base'])
    self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

https://github.com/pytorch/xla/actions/runs/5717993385/job/15493587952?pr=5377

I will update this PR to test a fix.
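
The aliasing described above can be reproduced without torch. The following is a minimal sketch, assuming a toy optimizer whose `state_dict()` returns a reference to its live state. `ToyOptimizer`, `copy_on_load`, and `saved_step_after_second_step` are hypothetical names made up for illustration; they are not part of torch or torch_xla, and the real optimizers hold tensors rather than a plain counter.

```python
from copy import deepcopy


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer; not a torch or torch_xla class."""

    def __init__(self, copy_on_load):
        self.copy_on_load = copy_on_load
        self.state = {"step": 0}

    def step(self):
        # Mutates the state in place, as a momentum buffer update would.
        self.state["step"] += 1

    def state_dict(self):
        # Returns a reference to the live state, not a snapshot.
        return self.state

    def load_state_dict(self, state_dict):
        # With copy_on_load the optimizer detaches from the caller's dict;
        # without it the optimizer keeps using the very same object.
        self.state = deepcopy(state_dict) if self.copy_on_load else state_dict


def saved_step_after_second_step(copy_on_load):
    opt = ToyOptimizer(copy_on_load)
    opt.step()
    saved = opt.state_dict()    # saved["step"] is 1 at this point
    opt.load_state_dict(saved)
    opt.step()                  # may or may not be visible through `saved`
    return saved["step"]


print(saved_step_after_second_step(copy_on_load=True))   # 1: saved dict untouched
print(saved_step_after_second_step(copy_on_load=False))  # 2: saved dict advanced
```

Under these assumptions, a test that wants a stable reference state regardless of the copying behavior could take `deepcopy(state_dict)` before loading, which is what the commented-out `s1_clone` / `s2_clone` lines in the snippet above appear to hint at.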

@janeyx99 force-pushed the test-print-statements branch from 8ab5fc7 to baff431 on July 31, 2023 16:15
@janeyx99 force-pushed the test-print-statements branch from baff431 to 0f139ab on July 31, 2023 17:58