Skip to content

add reduce=True argument to MultiLabelMarginLoss#4924

Merged
soumith merged 4 commits intopytorch:masterfrom
li-roy:multilabelmarginloss_reduce
Feb 5, 2018
Merged

add reduce=True argument to MultiLabelMarginLoss#4924
soumith merged 4 commits intopytorch:masterfrom
li-roy:multilabelmarginloss_reduce

Conversation

@li-roy
Copy link
Copy Markdown
Contributor

@li-roy li-roy commented Jan 30, 2018

As per #264. When reduce is False, MultiLabelMarginLoss outputs a loss per sample in minibatch. When reduce is True (default), the current behavior is kept.

Test Plan
test/run_test.sh
Added unit test. For the reduce=False case. Added unit tests for 1d tensors. Added a reference function.

@pytorchbot
Copy link
Copy Markdown
Collaborator

@li-roy, thanks for your PR! We identified @zdevito to be a potential reviewer.

@li-roy
Copy link
Copy Markdown
Contributor Author

li-roy commented Jan 30, 2018

@pytorchbot retest this please

@li-roy li-roy closed this Jan 30, 2018
@li-roy li-roy reopened this Jan 30, 2018
@zou3519
Copy link
Copy Markdown
Contributor

zou3519 commented Jan 30, 2018

@pytorchbot retest this please

Copy link
Copy Markdown
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't reviewed the C code or kernels yet, will get to that soon. I left a few comments on the python side to start out


template <typename Dtype, typename Acctype>
__global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
Dtype *gradOutput,

This comment was marked as off-topic.

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
size_average. Default: True

Shape:
- Input: :math:`(N)` or :math:`(N, *)` where `*` means, any number of additional

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
- Input: :math:`(N)` or :math:`(N, *)` where `*` means, any number of additional
dimensions
- Target: :math:`(N)` or :math:`(N, *)`, same shape as the input
- Output: scalar. If `reduce` is False, then `(N)` or `(N, *)`, same shape as

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
@@ -565,10 +565,32 @@ class MultiLabelMarginLoss(_Loss):
The criterion only considers the first non-negative `y[j]` targets.

This comment was marked as off-topic.

Comment thread test/test_nn.py


def multilabelmarginloss_no_reduce_test():
t = Variable(torch.rand(5, 10).mul(10).floor().long())

This comment was marked as off-topic.

Comment thread test/common_nn.py
check_no_size_average=True,
),
dict(
module_name='MultiLabelMarginLoss',

This comment was marked as off-topic.

Comment thread test/common_nn.py Outdated
if input.dim() == 1:
n = 1
dim = input.size()[0]
output = torch.Tensor(n).zero_()

This comment was marked as off-topic.

Comment thread test/common_nn.py Outdated
return output


def _multilabelmarginloss_reference(input, target, is_target):

This comment was marked as off-topic.

Comment thread test/common_nn.py Outdated

if input.dim() == 1:
n = 1
dim = input.size()[0]

This comment was marked as off-topic.

Comment thread test/common_nn.py Outdated
def _multilabelmarginloss_reference(input, target, is_target):
sum = 0
for i in range(0, target.size()[0]):
target_index = target[i]

This comment was marked as off-topic.

Copy link
Copy Markdown
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The THNN code looks good for the most part! Just a few nits here and there. Haven't looked at the cuda kernel yet.

isTarget_data += dim;
gradInput_data += dim;
}
gradInput_data -= nframe*dim;

This comment was marked as off-topic.

}

sum /= dim;
THTensor_(set1d)(output, t, sum);

This comment was marked as off-topic.

else
{
THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, nframe);
gradOutput = THTensor_(newContiguous)(gradOutput);

This comment was marked as off-topic.

THTensor_(free)(input);
THIndexTensor_(free)(target);
THTensor_(free)(isTarget);
THTensor_(free)(gradOutput);

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
reduce (bool, optional): By default, the losses are averaged or summed over
observations for each minibatch depending on size_average. When reduce
is False, returns a loss per batch element instead and ignores
size_average. Default: True

This comment was marked as off-topic.

@@ -133,6 +140,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra
}
__syncthreads();

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
reduce (bool, optional): By default, the losses are averaged or summed over
observations for each minibatch depending on size_average. When reduce
is False, returns a loss per batch element instead and ignores
size_average. Default: True

This comment was marked as off-topic.

Comment thread torch/nn/modules/loss.py Outdated
size_average is set to ``False``, the losses are instead summed for
each minibatch. Default: ``True``
reduce (bool, optional): By default, the losses are averaged or summed over
observations for each minibatch depending on size_average. When reduce

This comment was marked as off-topic.

THIndexTensor_(free)(target);
THTensor_(free)(isTarget);
THTensor_(free)(gradOutput);
THTensor_(free)(gradInput);

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Copy Markdown
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All comments addressed, LGTM unless anyone has something else to add. Thanks @li-roy!

@ssnl
Copy link
Copy Markdown
Collaborator

ssnl commented Feb 2, 2018

Thanks @li-roy !

@zou3519
Copy link
Copy Markdown
Contributor

zou3519 commented Feb 2, 2018

I'm not sure why the CI isn't running, but... let's give this one final test

@zou3519
Copy link
Copy Markdown
Contributor

zou3519 commented Feb 2, 2018

@pytorchbot retest this please

@soumith soumith merged commit 28f056f into pytorch:master Feb 5, 2018
@soumith soumith added 0.3.1 and removed 0.3.1 labels Feb 5, 2018
@li-roy li-roy deleted the multilabelmarginloss_reduce branch February 22, 2018 00:17
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
* add reduce=True argument to MultiLabelMarginLoss

* Fix lint

* Addressed comments

* Remove unneeded syncthreads calls
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants