Skip to content

[refactor] OSS only use flat buffers#371

Merged
blefaudeux merged 8 commits intomasterfrom
oss_flat
Feb 8, 2021
Merged

[refactor] OSS only use flat buffers#371
blefaudeux merged 8 commits intomasterfrom
oss_flat

Conversation

@blefaudeux
Copy link
Copy Markdown
Contributor

@blefaudeux blefaudeux commented Feb 6, 2021

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Just a natural extension of the params being tensor views of a bigger tensor (#300), no need to keep the non-bucketed params, I feel a little dumb but it's just much simpler and just as fast to have one single buffer per device and per rank (or faster for multi node, less latency).

So this PR just does that, all the trainable params which belong to the same device/rank link are packed into a single tensor, like before but not bounded. There's no tradeoff really, the original parameter memory is released.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 6, 2021
@blefaudeux blefaudeux marked this pull request as draft February 6, 2021 18:00
@blefaudeux blefaudeux marked this pull request as ready for review February 6, 2021 18:40
Copy link
Copy Markdown
Contributor

@min-xu-ai min-xu-ai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like a really nice simplification and reduces the memory use.

the max size of the buffer used to batch the small parameter tensors, in number of elements (default 16M).
this will not impact the long term memory consumption, but the peak memory can be impacted by the moment
when the buffers are allocated and the bucketed params have not yet been relocated to them.
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not being used -> not being used because broadcast buffer not long exists?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically all the parameters tied to a given recipient are now one single tensor (+tensor views so that the FW pass does not see that), this setting is not respected anymore. I should have done that from the moment I realized that tensor views were an option (I thought initially that it was only on recent pytorch, turns out it was already there in 1.5), but at the time I got a bit stuck in the mental model for the reduce, where the buckets are more of a compromise right now (it's more complicated or not possible to release them -opposed to gradient sharding-, and they hold the whole BW back instead of overlapping coms/compute).
In this case (optimizer step + broadcast the results) there's no reason not to batch all the parameters in one go (limited by device/ranks where they go), not that I can see at least, so this option is not useful anymore because the buffers are not capped (no duplicated memory, the original param storage is released), I let it there for backward compatibility only

Copy link
Copy Markdown
Contributor

@msbaines msbaines left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Some minor nits (non-blocking).

@blefaudeux blefaudeux merged commit 77d9486 into master Feb 8, 2021
@blefaudeux blefaudeux deleted the oss_flat branch February 8, 2021 05:01
myleott pushed a commit that referenced this pull request Feb 22, 2021
* [chore] Fix lint errors that broke master (#348)

authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [fix] ShardedDDP - cpu testfix - remove Gloo/CPU (#350)

* no idea about the root issue, but it proved to be fairly narrowed (gloo+cpu+python3.8+no cuda installed) so I guess that's out of scope for fairscale

* [feat][OSS] elastic and pytorch compatible checkpoints (#310)

* adding a test to prove the inter operability with upstream pytorch
* updating the changelog
* eager state pruning
* pytorch 1.5 compat

* [fix] ShardedDDP - properly handle post device change (#353)

* adding the .to(device) support + unit testing
* doc update

* [feat] Add AdaScaleWrapper (#347)

* [feat] Add AdaScaleWrapper

- This enables a different API for wrapping an optimizer with AdaScale.
- This also enables AdaScale to be wrapped by OSS.
- However, OSS wrapping AdaScale results in different optimization,
  which future research will be needed to study its effects.

testing: add unit tests.

* addressed comment: typo

* [refactor] Refactor and enable multiprocess nn.Pipe benchmarks. (#319)

* mp cleanup

* round of multiprocess refactoring

* test golden run

* print cuda stats

* fix lint errors

* enable multiprocess pipe benchmarks

* set world size to be available gpus

* more changes

* use synthetic loaders for intermediate pipeline stages

* merged master

* fix for the devices property

* dataloader fix

* modify rank check

* print wps stats

* enable verification

* fix logging

* fix flag name

* fix flag name

* check for rank

* fix indent

* pass args

* pass args

* modify golden data

* remove unused print messsage

* fix lint errors

* add comments

* fix benchmarks

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [refactor] pipe: simplify balance and module checks (#346)

* [chore] v0.1.5 (#355)

* [chore] disheartening switch off of a OSS cpu test (#356)

* precise skip, only if agent has only cpu

* [feat][minor] OSS Benchmark - regression test + background testing new optims (#352)

* restoring the regression test, adding a test of the for_each optims
* fix the regression test on circleci
* removing unused flags

* [refactor] multiprocess_pipe: cleanup __init__ (#357)

* [refactor] multiprocess_pipe: remove retain_graph __init__ param (#358)

It is not currently being used so we can simplify the interface
by removing it.

* [refactor] multiprocess_pipe: focus on LazyModule usage (#360)

* [feat] ShardedDDP : Adding a proper DDP parity / AMP unit test, overdue (#361)

* Adding a proper ddp parity / AMP unit test, overdue
* catch non-AMP pytorch

* [perf][OSS] Clip grad norm : minor obvious speedup (#363)

cache this iterator, easy speed up

* [refactor] multiprocess_pipe: remove pipelined_backward (#362)

* [perf] ShardedDDP - small memory use reduction - minor speedup (#366)

* minor

* minor

* [fix] repro+fix (#365)

fix a broken earlier commit, only worked for the first step

* [refactor] OSS only use flat buffers (#371)

* flat params all along, way simpler
* updating the docstring

* [refactor] AsyncPipe: do not sub-class MultiProcessPipe (#370)

* [refactor] remove multiprocess dependency on async (#373)

* [fix] Workaround need for pip --no-build-isolation (#375)

* Add fairscale.nn.misc.checkpoint_activations (#376)

* Add fairscale.utils.containers

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Add fairscale.nn.misc.checkpoint_activations

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* [chore] v0.1.6 (#377)

* v0.1.6

Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>
Co-authored-by: Benjamin Lefaudeux <blefaudeux@users.noreply.github.com>
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
Co-authored-by: msbaines <35972327+msbaines@users.noreply.github.com>
Co-authored-by: Leonard Lausen <leonard@lausen.nl>
Co-authored-by: Myle Ott <myleott@fb.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants