
Conversation

@stas00 stas00 (Collaborator) commented Feb 10, 2021

This PR adds instructions to install 1bit_adam dependencies

I think there are further fixes that are needed.

  1. The config at https://www.deepspeed.ai/tutorials/onebit-adam/#32-configuration-for-bingbertsquad-with-deepspeed-and-1-bit-adam-enabled
    suggests that the user add:

"max_grad_norm": 1.0,

but the program crashes when I add it (a corrected sketch follows the config dump below):

  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/__init__.py", line 110, in initialize
    engine = DeepSpeedEngine(args=args,

  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/engine.py", line 174, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/engine.py", line 549, in _configure_optimizer
    basic_optimizer = self._configure_basic_optimizer(model_parameters)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/engine.py", line 592, in _configure_basic_optimizer
    basic_optimizer = self._configure_basic_optimizer(model_parameters)
  File "/mnt/nvme1/code/github/00optimize/DeepSpeed-zombies/deepspeed/runtime/engine.py", line 592, in 
    raise ValueError(
ValueError: 'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details
  2. Here https://www.deepspeed.ai/tutorials/onebit-adam/#pre-requisites-for-1-bit-adam it says one has to use the special launcher, but it works just fine without it; that is, I was able to run it with just:

deepspeed script.py

The config I used was (logger dump):

    "fp16":{
        "enabled":true,
        "hysteresis":2,
        "initial_scale_power":16,
        "loss_scale":0,
        "loss_scale_window":1000,
        "min_loss_scale":1
    },
    "gradient_accumulation_steps":1,
    "gradient_clipping":1.0,
    "optimizer":{
        "params":{
            "bias_correction":false,
            "cuda_aware":true,
            "freeze_step":400,
            "lr":0.0002,
            "weight_decay":0.01
        },
        "type":"OneBitAdam"
    },
    "scheduler":{
        "params":{
            "warmup_max_lr":3e-05,
            "warmup_min_lr":0,
            "warmup_num_steps":500
        },
        "type":"WarmupLR"
    },
    "steps_per_print":2000,
    "train_micro_batch_size_per_gpu":4,
    "wall_clock_breakdown":false,
    "zero_allow_untested_optimizer":true,
    "zero_optimization":{
        "allgather_bucket_size":200000000.0,
        "allgather_partitions":true,
        "contiguous_gradients":true,
        "cpu_offload":true,
        "overlap_comm":true,
        "reduce_bucket_size":200000000.0,
        "reduce_scatter":true,
        "stage":2
    }
}
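
For reference, here is a minimal sketch of what the corrected tutorial snippet could look like, based on what the error message suggests: the clipping value moves out of the optimizer params and into the top-level "gradient_clipping" key (the values below are just the ones from my dump above, not necessarily the tutorial's):

    {
        "gradient_clipping": 1.0,
        "optimizer": {
            "type": "OneBitAdam",
            "params": {
                "lr": 0.0002,
                "freeze_step": 400,
                "cuda_aware": true
            }
        }
    }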

@conglongli conglongli (Contributor) commented:

@stas00 Thanks for the PR and the questions! For your first question: yes, it is a typo in our tutorial; let me make a commit to your PR to fix it. For your second question: currently 1-bit Adam uses an MPI backend to communicate during the compression stage (after "freeze_step" steps), so without the mvapich or openmpi launcher it may crash after reaching the compression stage, or you may not be able to achieve the expected throughput gain. We recently wrote an arXiv paper about 1-bit Adam, so feel free to give it a read if you are interested in the details: https://arxiv.org/abs/2102.02888. If you have more questions about 1-bit Adam, feel free to let me and @awan-10 know. Thanks!
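
To make the distinction concrete, here is a rough sketch of the two launch styles (the process count, hostfile path, and script name are placeholders; see the tutorial for the exact commands):

    # plain launcher: appears to work, but may crash or lose the expected
    # speedup once the compression stage starts after "freeze_step" steps
    deepspeed script.py

    # MPI-based launch (e.g. OpenMPI), which the 1-bit Adam tutorial assumes;
    # "cuda_aware": true additionally requires a CUDA-aware MPI build
    mpirun -np 8 -hostfile /path/to/hostfile python script.py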

@stas00 stas00 (Collaborator, Author) commented Feb 10, 2021

Thank you, @conglongli!

I was just running a quick does-it-work test, trying to rule out an error a user reported in transformers, but it should really be submitted here, since the crash is somewhere in the DeepSpeed code. So I asked her to file an issue here.

I'm primarily focused on figuring out the Pipeline right now, so I have no spare resources for 1-bit Adam at the moment. But this is definitely something I'd love to experiment with down the road.

For your second question: currently 1-bit Adam uses an MPI backend to communicate during the compression stage (after "freeze_step" steps), so without the mvapich or openmpi launcher it may crash after reaching the compression stage, or you may not be able to achieve the expected throughput gain.

Perhaps it'd be useful to add your comment as a note in the section where the launcher is discussed, so that an unwary user like myself won't conclude that all is good based on a quick test when seeing the training complete without a hitch?

@conglongli conglongli (Contributor) commented:

Perhaps it'd be useful to add your comment as a note in the section where the launcher is discussed, so that an unwary user like myself won't conclude that all is good based on a quick test when seeing the training complete without a hitch?

@stas00 Yes, I agree. Will do. I noticed that your PR is from your fork, so I will just approve and merge your PR first, and then create another PR to fix your two questions. And let us know when you have more bandwidth/interest/questions about 1-bit Adam :)

@conglongli conglongli (Contributor) left a comment:

LGTM

@jeffra jeffra merged commit 6beca3c into deepspeedai:master Feb 10, 2021
@stas00 stas00 deleted the one-bit-adam branch February 10, 2021 20:55
sdtblck added a commit to EleutherAI/DeeperSpeed that referenced this pull request Feb 11, 2021
* Dist testing backend fixes, etc. (deepspeedai#708)

* set_batch_fn and remove old sanity check (deepspeedai#712)

* properly set engine.local_rank if it's set to -1

* Add executable permission to `ds_elastic` and `ds_report` in `bin`. (deepspeedai#711)

* Add executable permission to `ds_elastic` and `ds_report` in `bin`.

* Automatic `ds_elastic` formatting

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* local rank of -1 means not set (deepspeedai#720)

* bump to 0.3.11

* [launcher] look ma, no more zombies (deepspeedai#714)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* Improve starred expressions (deepspeedai#696)

* Improve starred expressions

`deepspeed/profiling/flops_profiler/profiler.py` uses starred expressions
that are no longer valid with [PEP 617][1]. The new Python parser is in 3.9,
and this change allows DeepSpeed to run with the newest Python version. I have
not checked all locations that have this issue. However, this change allows me
to run simple examples.

[1]: https://www.python.org/dev/peps/pep-0617/

* Match style for "Improve starred expressions", although readability suffers

The style guide might need to be updated for this new use case of expressions.
Python [Issue 40631][1] includes more discussion on the change.

[1]: https://bugs.python.org/issue40631

Co-authored-by: Cheng Li <pistasable@gmail.com>

* Fixed typo in Readme. (deepspeedai#737)

* 1bit_adam dependencies (deepspeedai#742)

* Clickable screenshots (deepspeedai#746)

* Fix docstring

* Make screenshots clickable for easier viewing

* Add flops profiler tutorial (deepspeedai#682)

* work on flops profiler tutorial

* update flops profiler tutorial

* add flops profiler tutorial and fix names

* work on flops profiler tutorial

* update flops profiler tutorial

* add flops profiler tutorial and fix names

* fix tailing ws

* fix names

* remove multistep profiling and update docs

* fix cases where functionals and submodules coexist in a parent module, update readme

* fix typo

* always invoke post hook function

* fix module flops sum and update tests

* update tutorial

* Only initialize distributed if required (deepspeedai#734)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jon Eyolfson <eyolfson@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: TheDudeFromCI <thedudefromci@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Sean Naren <sean@grid.ai>