
[fix] OSS pytorch-compliant state dict#61

Merged
blefaudeux merged 3 commits into master from oss_pytorch_state_dict
Sep 3, 2020

Conversation

@blefaudeux
Contributor

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #60 .

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Except for the realization that all my lint errors were due to isort and black being version-dependent in their behaviour (FFS), all good 🙃

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 2, 2020
Contributor

@msbaines msbaines left a comment


What if you un-sharded in state-dict and then re-sharded in load_state_dict? That way, we work exactly like PyTorch is expecting.
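The un-shard/re-shard idea could be sketched roughly like this (hypothetical helper names, not the actual fairscale API; assumes `param_groups` are replicated across ranks and optimizer state is keyed by parameter id):

```python
def unshard(sharded_states):
    """Merge per-rank optimizer state dicts into one flat, PyTorch-style dict.

    Assumes each shard owns a disjoint set of parameter ids and that
    param_groups are replicated, so shard 0's copy is representative.
    """
    merged = {"state": {}, "param_groups": sharded_states[0]["param_groups"]}
    for shard in sharded_states:
        merged["state"].update(shard["state"])
    return merged


def reshard(flat_state, partition):
    """Split a flat state dict back into per-rank shards.

    `partition` is a list where entry i holds the parameter ids owned
    by rank i; param_groups are replicated into every shard.
    """
    return [
        {
            "state": {pid: flat_state["state"][pid] for pid in pids},
            "param_groups": flat_state["param_groups"],
        }
        for pids in partition
    ]
```

With this shape, `state_dict()` would call `unshard` so callers see one ordinary PyTorch state dict, and `load_state_dict()` would `reshard` it using the rank-to-parameter partition.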

```diff
-        return {"state": self._all_states}
+        return {
+            "state": [s["state"] for s in self._all_states],
+            "param_groups": [s["param_groups"] for s in self._all_states],
+        }
```
Contributor


Remove param_groups from _all_states to avoid having a redundant copy?

Contributor Author

@blefaudeux Sep 2, 2020


We don't own this state dict object; it's returned and could have any lifetime. I don't see the issue with a redundant copy?

Contributor


A redundant copy won't consume memory, since they are just references, but I'm wondering if it will increase disk space. Would be good to check that. I'm also wondering what will happen on load: will the two copies refer to the same objects, or will there be two copies? That could cause OOM.

OSS created the state_dicts in _collect_sharded_states. There are no references outside OSS at this point.
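On the serialization question, Python's pickler (which `torch.save` uses under the hood) memoizes objects, so a redundant reference neither doubles disk space nor duplicates the object on load. A quick way to check with plain stand-in values:

```python
import pickle

# A redundant reference to the same state dict: "state" and "all_states"
# both point at one object (placeholder list; a real checkpoint holds tensors).
state = {"momentum_buffer": [1.0] * 4}
checkpoint = {"state": [state], "all_states": [state]}

restored = pickle.loads(pickle.dumps(checkpoint))

# The pickler memoized the shared object, so after load both entries
# still refer to a single copy rather than two.
assert restored["state"][0] is restored["all_states"][0]
```

If the two entries referenced distinct copies before saving, however, they would stay distinct after load, so this only holds when the redundancy is by reference.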

Contributor Author

@blefaudeux Sep 2, 2020


Aaaah, I see what you mean, but it's not related to this PR, right? In general I agree that _all_states could be changed to remove duplicates, but to me this is a relatively independent change; I was only trying to address the get/load state interface. In the same fashion, it feels like there must be duplicates between self.param_groups, self._all_states and self.optim.param_groups.

Contributor Author


re: lifetime, I was referring to the object you commented on, return {}; I thought that was the subject of your comment.

Contributor


I'm worried about this potentially introducing a regression into fairseq. They are now using this in master, and I tested with only a single copy of the parameters. Maybe just confirm that disk space does not increase and that we don't get duplicate tensors for the params on load. Might be cool to add a memory benchmark (@andersonic added a memory assertion in one of the gpipe benchmarks). If the duplicates reference the same object, then it is not an issue.

@blefaudeux
Contributor Author

> What if you un-sharded in state-dict and then re-sharded in load_state_dict? That way, we work exactly like PyTorch is expecting.

The PyTorch expectation is not super clear to me; param_groups is a list (even if described as a dict), and there don't seem to be many more constraints in https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.state_dict than that.

You mean unrolling the list so that param_groups is [dict()]? I can do that.
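For reference, the shape that torch.optim.Optimizer.state_dict() documents is a "state" dict keyed by parameter index plus a "param_groups" list of dicts, which is the [dict()] unrolling mentioned above. A plain-Python mock of that shape (placeholder strings stand in for tensors):

```python
# Mock of the documented torch.optim state_dict layout; the values are
# placeholders, since real entries hold tensors such as momentum buffers.
pytorch_style_state_dict = {
    "state": {
        0: {"momentum_buffer": "tensor for param 0"},
        1: {"momentum_buffer": "tensor for param 1"},
    },
    "param_groups": [
        # one dict per group; "params" lists the parameter indices it owns
        {"lr": 0.1, "momentum": 0.9, "params": [0, 1]},
    ],
}
```

Matching this layout is what lets downstream consumers (checkpoint loaders, LR schedulers) treat the sharded optimizer like any other torch.optim.Optimizer.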

@blefaudeux blefaudeux marked this pull request as draft September 2, 2020 22:01
@blefaudeux blefaudeux changed the title Oss pytorch state dict [fix] OSS pytorch-compliant state dict Sep 3, 2020
@blefaudeux blefaudeux marked this pull request as ready for review September 3, 2020 17:41
@blefaudeux blefaudeux merged commit 1d1d15e into master Sep 3, 2020
@blefaudeux blefaudeux deleted the oss_pytorch_state_dict branch September 3, 2020 18:29

Labels

CLA Signed

Development

Successfully merging this pull request may close these issues.

Rewrite state_dict in a more pytorch idiomatic way
