@@ -33,6 +33,30 @@ def _orthogonalize(matrix, epsilon=1e-8):
 
 
 class PowerSGDState(object):
36+ """
37+ Stores both the gradient compression configs and the internal states for all the gradients during the training.
38+ Particularly, `matrix_approximation_rank` and `start_powerSGD_iter` are the main configs that need to be tuned by the user.
39+ Although `use_error_feedback` and `warm_start` can also be tuned by the user,
40+ they are typically turned on for performance.
41+
42+ Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`]
43+ ~~~~~~~~~~~~~~~~~~~~~~~~~~
44+ 1) To tune `matrix_approximation_rank`, the user can increase it from 1 by factors of 2,
45+ until a satisfying accuracy can be reached.
46+ The increase of `matrix_approximation_rank` can substantially increase the computation costs of the compression.
47+ However, the accuracy may not be futher improved beyond a certain `matrix_approximation_rank` value.
48+ 2) To tune `start_powerSGD_iter`, the user can typically start with 10% of total training steps,
49+ and increase it until a satisfying accuracy can be reached.
50+ Deferrring PowerSGD can effectively improve the accuracy,
51+ even a relatively small `matrix_approximation_rank` is used.
52+ This is because that, the beginning of training phase is usually very sensitive to inaccurate gradients,
53+ and compressing gradients too early may make the training quickly take a suboptimal trajectory,
54+ which can result in an irrecoverable impact on the accuracy.
55+ The minimum value allowed in DDP is 2, if error feedback or warm-up is enabled.
56+ This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP,
57+ and this can conflict with any tensor memorized before the rebuild process.
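+
+    A rough usage sketch (assuming a DDP model `ddp_model`; only `PowerSGDState` and
+    `powerSGD_hook` come from this file, everything else is illustrative)::
+
+        state = PowerSGDState(
+            process_group=None,  # None falls back to the default process group
+            matrix_approximation_rank=1,
+            start_powerSGD_iter=10,
+        )
+        ddp_model.register_comm_hook(state, powerSGD_hook)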
58+ """
+
     __slots__ = [
         "process_group",
         # The two fields below are the configs that usually need to be tuned by the user.
@@ -58,6 +82,16 @@ def __init__(
         warm_start=True,
         random_seed=0,
     ):
+        logging.info(
+            "PowerSGD config: matrix_approximation_rank = {}; "
+            "start_powerSGD_iter = {}; use_error_feedback = {}; warm_start = {}.".format(
+                matrix_approximation_rank,
+                start_powerSGD_iter,
+                use_error_feedback,
+                warm_start,
+            )
+        )
+
         self.process_group = process_group
         # The low rank for matrix approximation controls the size of compressed low-rank tensors,
         # which determines the computation ratio.
@@ -80,7 +114,11 @@ def __init__(
         # However, this means that the shape of input bucketized tensors is subject to change,
         # which will complicate the implementations of error feedback and warm-up.
         # Running vanilla allreduce in the first few iterations can avoid this complexity.
-        assert start_powerSGD_iter >= 1
+        if (use_error_feedback or warm_start) and start_powerSGD_iter <= 1:
+            raise ValueError(
+                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
+                "because PowerSGD can only be applied after the first two iterations in DDP."
+            )
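+        # For example, a (hypothetical) call such as
+        # `PowerSGDState(None, start_powerSGD_iter=1, use_error_feedback=True)`
+        # would be rejected here, because DDP rebuilds its buckets at iteration 1.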
         self.start_powerSGD_iter = start_powerSGD_iter
         # Error feedback is usually crucial for both convergence and generalization,
         # because PowerSGD is a biased compressor,
@@ -111,16 +149,6 @@ def __init__(
         # Iteration/step in the training loop.
         self.iter = 0
 
-        logging.info(
-            "PowerSGD config: matrix_approximation_rank = {}; "
-            "start_powerSGD_iter = {}; use_error_feedback = {}; warm_start = {}.".format(
-                self.matrix_approximation_rank,
-                self.start_powerSGD_iter,
-                self.use_error_feedback,
-                self.warm_start,
-            )
-        )
-
     def maybe_increase_iter(self, bucket):
         # Since bucket 0 is the last bucket to allreduce in an iteration,
         # only increase `iter` when bucket 0 is processed.
@@ -165,6 +193,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     Args:
         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
         bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
             Note that since DDP comm hook only supports single process single device mode at this time,
             only exactly one tensor is stored in this bucket.
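+
+    A minimal usage sketch (assuming a DDP model `ddp_model`; everything except the hook APIs is illustrative)::
+
+        state = PowerSGDState(process_group=None, matrix_approximation_rank=1, start_powerSGD_iter=10)
+        ddp_model.register_comm_hook(state, powerSGD_hook)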
@@ -399,6 +428,13 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
     7) Computes M, which is approximately equal to PQ^T.
     8) Truncates the input tensor to the original length.
 
+    This variant is faster than `powerSGD_hook`, which runs layer-wise gradient compression,
+    but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1.
+    Increasing `matrix_approximation_rank` may not necessarily increase the accuracy,
+    because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
+    Therefore, the user should always consider `powerSGD_hook` first,
+    and only consider this variant when a satisfactory accuracy can be achieved with `matrix_approximation_rank` set to 1.
+
     Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
     This not only gives the user finer control over the tradeoff between speedup and accuracy,
     but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.
@@ -409,6 +445,7 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
 
     Args:
         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
+            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
         bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
             Note that since DDP comm hook only supports single process single device mode at this time,
             only exactly one tensor is stored in this bucket.
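+
+    A minimal usage sketch (assuming a DDP model `ddp_model`; everything except the hook APIs is illustrative,
+    and rank 1 is used per the guidance above)::
+
+        state = PowerSGDState(process_group=None, matrix_approximation_rank=1)
+        ddp_model.register_comm_hook(state, batched_powerSGD_hook)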