feat(pd): support gradient accumulation #4920
Conversation
📝 Walkthrough

Add gradient accumulation controlled by a new `acc_freq` training option: micro-steps accumulate gradients, and the all-reduce, gradient clipping, optimizer update, gradient clearing, and scheduler step run only on accumulation boundaries.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant Optimizer
    participant Scheduler
    participant AllReduce as AllReduce (world_size>1)
    rect rgba(200,230,255,0.25)
        Note right of Trainer: Repeat for each micro-step (1..acc_freq)
        Trainer->>Trainer: forward()
        Trainer->>Trainer: backward() -- accumulate grads
    end
    rect rgba(220,255,200,0.25)
        Note right of Trainer: On accumulation boundary\n((_step_id+1) % acc_freq == 0)
        Trainer->>AllReduce: fused_allreduce_gradients() [if world_size>1]
        Trainer->>Trainer: gradient clipping() [if gradient_max_norm>0]
        Trainer->>Optimizer: step()
        Trainer->>Optimizer: clear_grad(set_to_zero=False)
        Trainer->>Scheduler: step()
    end
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pd/train/training.py (2)
835-856: Do not clear gradients during validation; it breaks accumulation. Use no_grad instead.

Calling self.optimizer.clear_grad() here zeroes partially accumulated gradients whenever display logging runs (e.g., on step 1), corrupting the accumulation window. Wrap the validation forward in paddle.no_grad() and remove the clear_grad.

```diff
-            for ii in range(valid_numb_batch):
-                self.optimizer.clear_grad()
-                input_dict, label_dict, _ = self.get_data(
-                    is_train=False, task_key=_task_key
-                )
-                if input_dict == {}:
-                    # no validation data
-                    return {}
-                _, loss, more_loss = self.wrapper(
-                    **input_dict,
-                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                    label=label_dict,
-                    task_key=_task_key,
-                )
+            for ii in range(valid_numb_batch):
+                with paddle.no_grad():
+                    input_dict, label_dict, _ = self.get_data(
+                        is_train=False, task_key=_task_key
+                    )
+                    if input_dict == {}:
+                        # no validation data
+                        return {}
+                    _, loss, more_loss = self.wrapper(
+                        **input_dict,
+                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                        label=label_dict,
+                        task_key=_task_key,
+                    )
```
888-901: Do not clear gradients when logging other-task training metrics; it breaks accumulation. Use no_grad.

Same issue as above. Clearing grads here will wipe partially accumulated grads in the current task when disp_training fires. Remove the clear and use no_grad.

```diff
-                    if _key != task_key:
-                        self.optimizer.clear_grad()
-                        input_dict, label_dict, _ = self.get_data(
-                            is_train=True, task_key=_key
-                        )
-                        _, loss, more_loss = self.wrapper(
-                            **input_dict,
-                            cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                            label=label_dict,
-                            task_key=_key,
-                        )
+                    if _key != task_key:
+                        with paddle.no_grad():
+                            input_dict, label_dict, _ = self.get_data(
+                                is_train=True, task_key=_key
+                            )
+                            _, loss, more_loss = self.wrapper(
+                                **input_dict,
+                                cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                                label=label_dict,
+                                task_key=_key,
+                            )
```
🧹 Nitpick comments (3)
deepmd/pd/train/training.py (3)
136-138: Validate and document the acc_freq input

Add a guard and a short docstring comment to ensure acc_freq is a positive integer. This avoids silent misconfigurations (e.g., 0 or negatives) and clarifies semantics.

```diff
         self.num_steps = training_params["numb_steps"]
         self.acc_freq: int = training_params.get(
             "acc_freq", 1
         )  # gradient accumulation steps
+        assert isinstance(self.acc_freq, int) and self.acc_freq >= 1, (
+            "training.acc_freq must be an integer >= 1"
+        )
```
794-801: Optional: move gradient clipping into the optimizer via grad_clip to reduce per-step Python overhead

Paddle supports optimizer-level gradient clipping (e.g., grad_clip=paddle.nn.ClipGradByGlobalNorm). This avoids per-step Python calls and makes behavior uniform.

Example:

```python
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.gradient_max_norm)
self.optimizer = paddle.optimizer.Adam(
    learning_rate=self.scheduler,
    parameters=self.wrapper.parameters(),
    grad_clip=grad_clip,
)
```
751-759: Consider tying pref_lr to the update count when accumulating

When acc_freq > 1, the "logical" step for LR scheduling is the optimizer update, not each micro-step. You're calling scheduler.step() only on update (good). For consistency, also consider computing pref_lr from the current scheduler.get_lr() (the update count) rather than _lr.value(_step_id). Today it's benign but can diverge for nontrivial LR policies.
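A rough standalone illustration of the drift (hypothetical snippet using paddle.optimizer.lr.ExponentialDecay; the trainer's actual scheduler and _lr.value are not reproduced here):

```python
import paddle

# With acc_freq = 4, the scheduler advances once per optimizer update, so an LR
# indexed by the raw micro-step id drifts from the LR the optimizer actually uses.
acc_freq = 4
sched = paddle.optimizer.lr.ExponentialDecay(learning_rate=1e-3, gamma=0.95)

for micro_step in range(8):
    lr_by_micro_step = 1e-3 * 0.95**micro_step  # what indexing by _step_id would give
    lr_used_by_optimizer = sched.get_lr()       # reflects the update count only
    print(micro_step, lr_by_micro_step, lr_used_by_optimizer)
    if (micro_step + 1) % acc_freq == 0:
        sched.step()  # advance once per optimizer update
```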
🔇 Additional comments (3)
deepmd/pd/train/training.py (2)
969-974: Metric tensor to scalar conversion looks good

Switching to .item() for loss and more_loss clarifies tensorboard logging types.
790-801: Please confirm fused_allreduce_gradients' scaling behavior before clipping/stepping

The use of hpu.fused_allreduce_gradients may sum or average gradients across ranks. To maintain single-process semantics and match typical DDP-style averaging, you need to ensure gradients are divided by world_size if they are summed:

- Location: deepmd/pd/train/training.py, around lines 790-801
- Action: Inspect fleet.utils.hybrid_parallel_util.hpu.fused_allreduce_gradients to determine whether it performs an all-reduce with SUM or AVG
- If it sums, insert immediately after the all-reduce and before gradient clipping / the optimizer step:

```python
for p in self.wrapper.parameters():
    if p.grad is not None:
        p._set_grad(p.grad / self.world_size)
```

This guarantees consistent gradient scaling across single- and multi-process runs.
source/tests/pd/test_training.py (1)
141-159: Good baseline test setup

Re-using the existing water dataset and enabling prim matches the Paddle path; this provides a solid baseline.
Force-pushed from ad2aeb3 to ed992ad.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pd/train/training.py (2)
828-858: Do not clear training gradients during validation logging; wrap validation in no_grad

Calling self.optimizer.clear_grad() here wipes accumulated training gradients before the accumulation boundary, breaking gradient accumulation correctness. Use no_grad for validation instead and avoid touching grads.

```diff
-        def log_loss_valid(_task_key="Default"):
+        def log_loss_valid(_task_key="Default"):
             single_results = {}
             sum_natoms = 0
             if not self.multi_task:
                 valid_numb_batch = self.valid_numb_batch
             else:
                 valid_numb_batch = self.valid_numb_batch[_task_key]
             for ii in range(valid_numb_batch):
-                self.optimizer.clear_grad()
                 input_dict, label_dict, _ = self.get_data(
                     is_train=False, task_key=_task_key
                 )
                 if input_dict == {}:
                     # no validation data
                     return {}
-                _, loss, more_loss = self.wrapper(
-                    **input_dict,
-                    cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                    label=label_dict,
-                    task_key=_task_key,
-                )
+                with paddle.no_grad():
+                    _, loss, more_loss = self.wrapper(
+                        **input_dict,
+                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                        label=label_dict,
+                        task_key=_task_key,
+                    )
```
884-901: Avoid clearing grads and building graphs in display-time per-task training passes

The clear_grad() call here also destroys accumulated gradients when display fires mid-accumulation. These display-only forwards should be under no_grad and must not touch optimizer grads.

```diff
-                    if _key != task_key:
-                        self.optimizer.clear_grad()
-                        input_dict, label_dict, _ = self.get_data(
-                            is_train=True, task_key=_key
-                        )
-                        _, loss, more_loss = self.wrapper(
-                            **input_dict,
-                            cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                            label=label_dict,
-                            task_key=_key,
-                        )
+                    if _key != task_key:
+                        input_dict, label_dict, _ = self.get_data(
+                            is_train=True, task_key=_key
+                        )
+                        with paddle.no_grad():
+                            _, loss, more_loss = self.wrapper(
+                                **input_dict,
+                                cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                                label=label_dict,
+                                task_key=_key,
+                            )
```
♻️ Duplicate comments (4)
deepmd/pd/train/training.py (3)
789-792: Confirm fused_allreduce_gradients semantics (SUM vs AVG) and normalize if needed

If fused_allreduce produces SUM, gradients are larger by world_size after the all-reduce; average them to match single-process scale unless you intentionally adjust the LR. Please verify and normalize accordingly.

If SUM:

```diff
                 if self.world_size > 1:
                     hpu.fused_allreduce_gradients(
                         list(self.wrapper.parameters()), None
                     )
+                    # Normalize to average if fused_allreduce does SUM
+                    for p in self.wrapper.parameters():
+                        if p.grad is not None:
+                            p._set_grad(p.grad / self.world_size)
```

In PaddlePaddle's fleet.utils.hybrid_parallel_util.hpu.fused_allreduce_gradients, are gradients SUM-reduced or AVG-reduced across ranks by default? Provide an authoritative citation.
782-784: Scale the loss by acc_freq to preserve the effective learning rate during accumulation

Without scaling, gradients are effectively multiplied by acc_freq, changing optimization behavior and often destabilizing training. Average the loss over micro-steps.

```diff
-            with nvprof_context(enable_profiling, "Backward pass"):
-                loss.backward()
+            with nvprof_context(enable_profiling, "Backward pass"):
+                scaled_loss = loss / float(self.acc_freq)
+                scaled_loss.backward()
```
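The scaling argument behind this suggestion, written out with $K$ = acc_freq and $L_i$ the per-micro-step losses (a restatement of the point above, not new analysis):

$$
g_{\text{acc}} \;=\; \sum_{i=1}^{K} \nabla_\theta L_i ,\qquad
\nabla_\theta\!\left(\frac{1}{K}\sum_{i=1}^{K} L_i\right) \;=\; \frac{1}{K}\, g_{\text{acc}} .
$$

Dividing each loss by $K$ before backward() therefore makes the accumulated gradient the average over the window, matching the gradient scale of a single un-accumulated step.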
785-806: Fix the off-by-one accumulation trigger and flush the remainder on the final iteration

The current trigger uses (_step_id + 1) % acc_freq == 0 and ignores start_step; it also skips flushing a partial accumulation on the last batch when (num_steps - start_step) % acc_freq != 0. This silently drops gradients.

```diff
-            # gradient accumulation
-            if (_step_id + 1) % self.acc_freq == 0:
+            # gradient accumulation
+            accum_step = _step_id - self.start_step + 1
+            is_last_iter = (_step_id + 1) == self.num_steps
+            if (accum_step % self.acc_freq == 0) or is_last_iter:
                 # fuse + allreduce manually before optimization if use DDP + no_sync
                 # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
                 if self.world_size > 1:
                     hpu.fused_allreduce_gradients(
                         list(self.wrapper.parameters()), None
                     )
                 if self.gradient_max_norm > 0.0:
                     with nvprof_context(enable_profiling, "Gradient clip"):
                         paddle.nn.utils.clip_grad_norm_(
                             self.wrapper.parameters(),
                             self.gradient_max_norm,
                             error_if_nonfinite=True,
                         )
                 with nvprof_context(enable_profiling, "Adam update"):
                     self.optimizer.step()
                 self.optimizer.clear_grad(set_to_zero=False)
                 self.scheduler.step()
```

source/tests/pd/test_training.py (1)
161-172: Make the accumulation test actually accumulate and avoid display-time interference

With numb_steps=1 and acc_freq=4, no accumulation occurs, and the step-1 display will zero grads in the current implementation. Run at least acc_freq steps, disable in-training display, and optionally assert that prim is enabled.

```diff
     def setUp(self) -> None:
@@
         self.config["model"] = deepcopy(model_se_e2_a)
-        self.config["training"]["numb_steps"] = 1
+        # Use acc_freq steps so we actually accumulate before an optimizer step
+        self.config["training"]["numb_steps"] = 4
         self.config["training"]["save_freq"] = 1
         self.config["training"]["acc_freq"] = 4
+        # Avoid display-time eval that may touch grads during accumulation
+        self.config["training"]["disp_training"] = False
         enable_prim(True)
+        # Optional sanity check
+        # assert paddle.framework.core._is_eager_prim_enabled(), \
+        #     "Eager prim should be enabled for gradient accumulation tests"
```
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (1)
811-816: Optional: only run heavy display/validation at accumulation boundaries

To avoid extra compute and any future risk of interfering with in-flight accumulation, consider logging only when you step the optimizer or on the very first batch.

Example gate:

```diff
-            if self.display_in_training and (
-                display_step_id % self.disp_freq == 0 or display_step_id == 1
-            ):
+            if self.display_in_training and (
+                display_step_id % self.disp_freq == 0 or display_step_id == 1
+            ) and (
+                ((_step_id + 1 - self.start_step) % self.acc_freq == 0)
+                or ((_step_id + 1) == self.num_steps)
+            ):
```
source/tests/pd/test_training.py (1)

161-172: Optional: add a remainder-flush test case

To cover the final-step flush logic, consider setting numb_steps to acc_freq + 1 (e.g., 5 with acc_freq=4) and asserting one optimizer update for the remainder. I can provide a concrete test if desired.
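A hypothetical sketch of that case, reusing the accumulation test's setup (the class name TestEnergyModelSeAGradAcc and its base setUp are assumed from this PR, not verified):

```python
class TestEnergyModelSeAGradAccRemainder(TestEnergyModelSeAGradAcc):
    def setUp(self) -> None:
        super().setUp()
        # acc_freq + 1 steps leaves one accumulated micro-step after the last
        # full window; the final-iteration flush must still apply it.
        acc_freq = self.config["training"]["acc_freq"]
        self.config["training"]["numb_steps"] = acc_freq + 1
```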
🔇 Additional comments (1)
source/tests/pd/test_training.py (1)
155-157: LGTM: tearDown added for TestEnergyModelSeA

Good cleanup, consistent with the rest of the suite.
Codecov Report

❌ Patch coverage is below 100% for this change.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##            devel    #4920      +/-   ##
==========================================
- Coverage   84.29%   84.29%   -0.01%
==========================================
  Files         703      703
  Lines       68728    68731       +3
  Branches     3572     3573       +1
==========================================
+ Hits        57936    57938       +2
+ Misses       9653     9652       -1
- Partials     1139     1141       +2
```

☔ View full report in Codecov by Sentry.
njzjz left a comment:
I am wondering whether an error will be thrown during the argument check if you don't add this argument to deepmd/utils/argcheck.py.
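For reference, a minimal sketch of what registering the option in deepmd/utils/argcheck.py could look like with dargs (the exact placement among the training arguments and the Paddle-only doc-flag handling are assumptions, not taken from the PR):

```python
from dargs import Argument

# assumed helper text; the real file prefixes backend-specific notices
doc_acc_freq = "The accumulation steps for the gradients."

acc_freq_arg = Argument(
    "acc_freq", int, optional=True, default=1, doc=doc_acc_freq
)
```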
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
3171-3171: Clarify acc_freq semantics and constraints in the docs

Explicitly state the meaning, the constraint (>= 1), and the effective batch size to avoid user confusion.

```diff
-    doc_acc_freq = "The accumulation steps for the gradients."
+    doc_acc_freq = (
+        "Number of mini-batches to accumulate gradients before applying optimizer "
+        "and scheduler updates. Must be >= 1. Effective batch size = batch_size * acc_freq."
+    )
```
🔇 Additional comments (1)
deepmd/utils/argcheck.py (1)
43-43: Paddle-only doc flag addition looks good

Consistent with existing TF/PT flags; no issues.
Pull Request Overview
This PR adds gradient accumulation support to the Paddle backend of DeePMD-kit. Gradient accumulation batches optimizer updates across multiple steps, which allows a larger effective batch size within the same memory budget and can improve training stability.
Key changes:
- Adds an acc_freq configuration parameter to control gradient accumulation frequency
- Modifies the training loop to accumulate gradients and perform updates at the specified interval (a generic sketch of this pattern follows the list)
- Adds test coverage for gradient accumulation functionality
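A minimal, generic sketch of the accumulation pattern in plain Paddle (not the trainer's actual code; model, loss_fn, and loader are placeholders):

```python
import paddle

def train(model, loss_fn, loader, num_steps, acc_freq=1, max_norm=0.0):
    opt = paddle.optimizer.Adam(parameters=model.parameters())
    for step_id, (x, y) in zip(range(num_steps), loader):
        loss = loss_fn(model(x), y) / acc_freq  # average the loss over the window
        loss.backward()                         # gradients accumulate across micro-steps
        if (step_id + 1) % acc_freq == 0 or (step_id + 1) == num_steps:
            if max_norm > 0.0:
                paddle.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            opt.step()        # one optimizer update per acc_freq micro-steps
            opt.clear_grad()  # start the next accumulation window
```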
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| deepmd/utils/argcheck.py | Adds acc_freq argument definition with Paddle backend documentation |
| deepmd/pd/train/training.py | Implements gradient accumulation logic in training loop |
| source/tests/pd/test_training.py | Adds test case for gradient accumulation functionality |
Support gradient accumulation for the Paddle backend.

## Summary by CodeRabbit

- **New Features**
  - Configurable gradient accumulation (acc_freq) that batches optimizer updates, optional gradient clipping, and multi-GPU gradient sync to occur at the configured interval; acc_freq=1 preserves prior behavior.
- **Documentation**
  - Added argument docs and a Paddle backend notice describing acc_freq.
- **Tests**
  - Added tests exercising gradient accumulation and updated test cleanup.
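For example, the new option sits alongside the other training keys in the input script (values here are illustrative; only acc_freq is new in this PR):

```json
{
  "training": {
    "numb_steps": 100000,
    "disp_freq": 100,
    "save_freq": 1000,
    "acc_freq": 4
  }
}
```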