Skip to content

Conversation

@to3i
Copy link
Contributor

@to3i to3i commented Jul 14, 2014

This PR adds a max function option to the elementwise layer as suggested in #654 . Tests are included, however, all four gradient tests fail using the elementwise checker!

For the gradient computation I assume that only the gradients of maximum values in forward pass are propagated backwards, while all the remaining values are set to zero. I wonder if the gradient calculation is flawed or the gradient checker needs to be adapted to deal with the elementwise max case. Maybe someone familiar with the gradient checker could have a look.

Here is some output from the GPU test run:

[----------] 2 tests from EltwiseLayerTest/3, where TypeParam = caffe::DoubleGPU
[ RUN      ] EltwiseLayerTest/3.TestMax
[       OK ] EltwiseLayerTest/3.TestMax (1 ms)
[ RUN      ] EltwiseLayerTest/3.TestMaxGradient
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.016865366604179588, which exceeds threshold_ * scale, where
computed_gradient evaluates to 0,
estimated_gradient evaluates to 0.016865366604179588, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,13,0,13
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.0168653666041787, which exceeds threshold_ * scale, where
computed_gradient evaluates to 1,
estimated_gradient evaluates to 0.9831346333958213, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,13,1,13
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.22881695721298412, which exceeds threshold_ * scale, where
computed_gradient evaluates to 1,
estimated_gradient evaluates to 0.77118304278701588, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,15,1,15
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.22881695721298501, which exceeds threshold_ * scale, where
computed_gradient evaluates to 0,
estimated_gradient evaluates to 0.22881695721298501, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,15,2,15
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.46016938239336058, which exceeds threshold_ * scale, where
computed_gradient evaluates to 0,
estimated_gradient evaluates to 0.46016938239336058, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,36,1,36
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.46016938239335969, which exceeds threshold_ * scale, where
computed_gradient evaluates to 1,
estimated_gradient evaluates to 0.53983061760664031, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,36,2,36
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.33487495919689581, which exceeds threshold_ * scale, where
computed_gradient evaluates to 0,
estimated_gradient evaluates to 0.33487495919689581, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,58,0,58
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.33487495919689492, which exceeds threshold_ * scale, where
computed_gradient evaluates to 1,
estimated_gradient evaluates to 0.66512504080310508, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,58,1,58
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.19332874752581164, which exceeds threshold_ * scale, where
computed_gradient evaluates to 0,
estimated_gradient evaluates to 0.19332874752581164, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,65,0,65
./include/caffe/test/test_gradient_check_util.hpp:165: Failure
The difference between computed_gradient and estimated_gradient is 0.19332874752581075, which exceeds threshold_ * scale, where
computed_gradient evaluates to 1,
estimated_gradient evaluates to 0.80667125247418925, and
threshold_ * scale evaluates to 0.001.
debug: (top_id, top_data_id, blob_id, feat_id)=0,65,2,65
[  FAILED  ] EltwiseLayerTest/3.TestMaxGradient, where TypeParam = caffe::DoubleGPU (48 ms)
[----------] 2 tests from EltwiseLayerTest/3 (49 ms total)

@longjon
Copy link
Contributor

longjon commented Jul 14, 2014

This is likely due to the nonsmoothness of the max function. The gradient checker has limited support for gradient discontinuities with the kink parameters, but we probably need to either (1) think about how to generalize it to handle cases like this, or (2) postpone thinking by writing some code that adjusts kink appropriately while testing this layer.

@shelhamer shelhamer force-pushed the dev branch 3 times, most recently from 4278286 to c01f07a Compare August 28, 2014 07:00
jeffdonahue added a commit that referenced this pull request Sep 10, 2014
@jeffdonahue
Copy link
Contributor

wrapped up in #1053, thanks @to3i!

This was referenced Sep 18, 2014
mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
RazvanRanca pushed a commit to RazvanRanca/caffe that referenced this pull request Nov 4, 2014
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants