@@ -33,6 +33,7 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
3333 << param.DebugString ();
3434 param_ = param;
3535 CHECK_GE (param_.average_loss (), 1 ) << " average_loss should be non-negative." ;
36+ CHECK_GE (param_.accum_grad (), 1 ) << " accum_grad must be at least 1." ;
3637 if (param_.random_seed () >= 0 ) {
3738 Caffe::set_random_seed (param_.random_seed ());
3839 }
@@ -164,6 +165,7 @@ void Solver<Dtype>::Step(int iters) {
164165 const int start_iter = iter_;
165166 const int stop_iter = iter_ + iters;
166167 int average_loss = this ->param_ .average_loss ();
168+ const int accum_grad = this ->param_ .accum_grad ();
167169 vector<Dtype> losses;
168170 Dtype smoothed_loss = 0 ;
169171
@@ -175,7 +177,17 @@ void Solver<Dtype>::Step(int iters) {
175177
176178 const bool display = param_.display () && iter_ % param_.display () == 0 ;
177179 net_->set_debug_info (display && param_.debug_info ());
178- Dtype loss = net_->ForwardBackward (bottom_vec);
180+ Dtype loss = 0 ;
181+ if (accum_grad > 1 ) {
182+ ResetAccumulateGradients ();
183+ for (int i = 0 ; i < accum_grad; ++i) {
184+ loss += net_->ForwardBackward (bottom_vec);
185+ AccumulateGradients ();
186+ }
187+ loss /= accum_grad;
188+ } else {
189+ loss = net_->ForwardBackward (bottom_vec);
190+ }
179191 if (losses.size () < average_loss) {
180192 losses.push_back (loss);
181193 int size = losses.size ();
@@ -430,6 +442,9 @@ void SGDSolver<Dtype>::PreSolve() {
430442 temp_.push_back (shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
431443 net_param->num (), net_param->channels (), net_param->height (),
432444 net_param->width ())));
445+ accum_.push_back (shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
446+ net_param->num (), net_param->channels (), net_param->height (),
447+ net_param->width ())));
433448 }
434449}
435450
@@ -458,12 +473,54 @@ void SGDSolver<Dtype>::ClipGradients() {
458473 }
459474}
460475
476+ template <typename Dtype>
477+ void SGDSolver<Dtype>::AccumulateGradients() {
478+ const vector<shared_ptr<Blob<Dtype> > >& net_params = this ->net_ ->params ();
479+ const int accum_grad = this ->param_ .accum_grad ();
480+ if (Caffe::mode () == Caffe::GPU) {
481+ #ifndef CPU_ONLY
482+ for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
483+ caffe_gpu_axpy (net_params[param_id]->count (), Dtype (1 . / accum_grad),
484+ net_params[param_id]->gpu_diff (),
485+ accum_[param_id]->mutable_gpu_data ());
486+ }
487+ #else
488+ NO_GPU;
489+ #endif
490+ } else {
491+ for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
492+ caffe_axpy (net_params[param_id]->count (), Dtype (1 . / accum_grad),
493+ net_params[param_id]->cpu_diff (),
494+ accum_[param_id]->mutable_cpu_data ());
495+ }
496+ }
497+ }
498+ template <typename Dtype>
499+ void SGDSolver<Dtype>::ResetAccumulateGradients() {
500+ if (Caffe::mode () == Caffe::GPU) {
501+ #ifndef CPU_ONLY
502+ for (int param_id = 0 ; param_id < accum_.size (); ++param_id) {
503+ caffe_gpu_set (accum_[param_id]->count (), Dtype (0 ),
504+ accum_[param_id]->mutable_gpu_data ());
505+ }
506+ #else
507+ NO_GPU;
508+ #endif
509+ } else {
510+ for (int param_id = 0 ; param_id < accum_.size (); ++param_id) {
511+ caffe_set (accum_[param_id]->count (), Dtype (0 ),
512+ accum_[param_id]->mutable_cpu_data ());
513+ }
514+ }
515+ }
516+
461517template <typename Dtype>
462518void SGDSolver<Dtype>::ComputeUpdateValue() {
463519 const vector<shared_ptr<Blob<Dtype> > >& net_params = this ->net_ ->params ();
464520 const vector<float >& net_params_lr = this ->net_ ->params_lr ();
465521 const vector<float >& net_params_weight_decay =
466522 this ->net_ ->params_weight_decay ();
523+ const int accum_grad = this ->param_ .accum_grad ();
467524 // get the learning rate
468525 Dtype rate = GetLearningRate ();
469526 if (this ->param_ .display () && this ->iter_ % this ->param_ .display () == 0 ) {
@@ -477,6 +534,10 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
477534 case Caffe::CPU:
478535 for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
479536 // Compute the value to history, and then copy them to the blob's diff.
537+ if (accum_grad > 1 ) {
538+ caffe_copy (accum_[param_id]->count (), accum_[param_id]->cpu_data (),
539+ net_params[param_id]->mutable_cpu_diff ());
540+ }
480541 Dtype local_rate = rate * net_params_lr[param_id];
481542 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
482543
@@ -513,6 +574,10 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
513574#ifndef CPU_ONLY
514575 for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
515576 // Compute the value to history, and then copy them to the blob's diff.
577+ if (accum_grad > 1 ) {
578+ caffe_copy (accum_[param_id]->count (), accum_[param_id]->gpu_data (),
579+ net_params[param_id]->mutable_gpu_diff ());
580+ }
516581 Dtype local_rate = rate * net_params_lr[param_id];
517582 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
518583
@@ -696,6 +761,7 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
696761 const vector<float >& net_params_lr = this ->net_ ->params_lr ();
697762 const vector<float >& net_params_weight_decay =
698763 this ->net_ ->params_weight_decay ();
764+ const int accum_grad = this ->param_ .accum_grad ();
699765 // get the learning rate
700766 Dtype rate = this ->GetLearningRate ();
701767 Dtype delta = this ->param_ .delta ();
@@ -708,6 +774,11 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
708774 switch (Caffe::mode ()) {
709775 case Caffe::CPU:
710776 for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
777+ if (accum_grad > 1 ) {
778+ caffe_copy (this ->accum_ [param_id]->count (),
779+ this ->accum_ [param_id]->cpu_data (),
780+ net_params[param_id]->mutable_cpu_diff ());
781+ }
711782 Dtype local_rate = rate * net_params_lr[param_id];
712783 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
713784
@@ -764,6 +835,11 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
764835 case Caffe::GPU:
765836#ifndef CPU_ONLY
766837 for (int param_id = 0 ; param_id < net_params.size (); ++param_id) {
838+ if (accum_grad > 1 ) {
839+ caffe_copy (this ->accum_ [param_id]->count (),
840+ this ->accum_ [param_id]->gpu_data (),
841+ net_params[param_id]->mutable_gpu_diff ());
842+ }
767843 Dtype local_rate = rate * net_params_lr[param_id];
768844 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
769845
0 commit comments