Skip to content

Enhance SGDP optimizer with caution parameter#2675

Merged
rwightman merged 5 commits intohuggingface:mainfrom
Yuan-Jinghui:main
Mar 5, 2026
Merged

Enhance SGDP optimizer with caution parameter#2675
rwightman merged 5 commits intohuggingface:mainfrom
Yuan-Jinghui:main

Conversation

@Yuan-Jinghui
Copy link
Copy Markdown
Contributor

Hi Ross, long time no see!

I have completed the implementation to add the cautious mask to the SGDP optimizer. The specific updates in this PR are as follows:

  • Added csgdp: I implemented the cautious mask for sgdp following the exact same logic used in cadamp.
  • Updated Citation: For the Spherical Cautious Optimizers, I updated the link to point to the permanent OpenReview acceptance URL.
  • Factory Registration: I registered csgdp in the optimizer factory.
  • Unit Tests: I added test functions for both csgdp and csgdw.

I have run the tests locally. Specifically for the Rosenbrock function, the output is as follows:

================================================= test session starts ================================================= 
platform win32 -- Python 3.10.18, pytest-9.0.2, pluggy-1.6.0 -- D:\anaconda\envs\manifold\python.exe
cachedir: .pytest_cache
rootdir: C:\Users\lenovo\Desktop\pytorch-image-models-main
configfile: pyproject.toml
plugins: anyio-4.9.0, dash-3.3.0
collected 130 items / 128 deselected / 2 selected

test_optim.py::test_optim_factory[csgdp] PASSED                                                                  [ 50%]
test_optim.py::test_csgdp[csgdp] PASSED                                                                          [100%]

========================================== 2 passed, 128 deselected in 7.17s ==========================================

The tests passed successfully. To verify that the training workflow completes without issues, I also ran the following scripts:

./distributed_train.sh 1 /home/cifar10     --dataset torch/cifar10     --model resnet18     --input-size 3 224 224     --num-classes 10     --batch-size 1024     --epochs 70     --opt csgdp     --lr 0.1     --weight-decay 5e-4     --amp     --experiment quick_csgdp_test

and

./distributed_train.sh 1 /home/cifar10     --dataset torch/cifar10     --model resnet18     --input-size 3 224 224     --num-classes 10     --batch-size 1024     --epochs 70     --opt csgdw     --lr 0.1     --weight-decay 5e-4     --amp     --experiment quick_csgdw_test

Here are the respective results:

CSGDP:

Current checkpoints:
 ('./output/train/quick_csgdp_test/checkpoint-68.pth.tar', 87.52999996337891)
 ('./output/train/quick_csgdp_test/checkpoint-69.pth.tar', 87.4599998046875)
 ('./output/train/quick_csgdp_test/checkpoint-66.pth.tar', 87.44999958496093)
 ('./output/train/quick_csgdp_test/checkpoint-67.pth.tar', 87.3799996459961)
 ('./output/train/quick_csgdp_test/checkpoint-65.pth.tar', 87.24999958496093)
 ('./output/train/quick_csgdp_test/checkpoint-64.pth.tar', 87.11999986572266)
 ('./output/train/quick_csgdp_test/checkpoint-63.pth.tar', 87.0999998046875)
 ('./output/train/quick_csgdp_test/checkpoint-62.pth.tar', 86.93000004882812)
 ('./output/train/quick_csgdp_test/checkpoint-61.pth.tar', 86.90000008544922)
 ('./output/train/quick_csgdp_test/checkpoint-60.pth.tar', 86.70999958496094)

Best metric: 87.52999996337891 (epoch 68)

CSGDW:

Current checkpoints:
 ('./output/train/quick_csgdw_test/checkpoint-63.pth.tar', 86.68999970703125)
 ('./output/train/quick_csgdw_test/checkpoint-66.pth.tar', 86.68000008544922)
 ('./output/train/quick_csgdw_test/checkpoint-60.pth.tar', 86.60999986572266)
 ('./output/train/quick_csgdw_test/checkpoint-69.pth.tar', 86.57999970703125)
 ('./output/train/quick_csgdw_test/checkpoint-67.pth.tar', 86.54999986572265)
 ('./output/train/quick_csgdw_test/checkpoint-68.pth.tar', 86.50999976806641)
 ('./output/train/quick_csgdw_test/checkpoint-64.pth.tar', 86.48999986572265)
 ('./output/train/quick_csgdw_test/checkpoint-58.pth.tar', 86.48000008544922)
 ('./output/train/quick_csgdw_test/checkpoint-65.pth.tar', 86.4399996459961)
 ('./output/train/quick_csgdw_test/checkpoint-61.pth.tar', 86.40999986572265)

Best metric: 86.68999970703125 (epoch 63)

These preliminary results demonstrate that ‘csgdp‘ can run successfully and maintain stable convergence.

Looking forward to your review when you have some free time!

Added 'caution' parameter to SGDP optimizer for enhanced functionality.
Fix reference link for Spherical Cautious Optimizers
Clone the buffer before using it for the update.
@rwightman rwightman merged commit 85bb330 into huggingface:main Mar 5, 2026
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants