Change keepdim argument to False in loss function #9

Merged
hyp1231 merged 1 commit into RUCAIBox:master from yinpu:master on Feb 20, 2023
yinpu commented Feb 20, 2023
Contributor

This commit removes the keepdim=True argument from the loss function, since it is unnecessary. In the original code, neg_logits has shape (batch_size,) and pos_logits has shape (batch_size, 1), so dividing the two broadcasts to a result of shape (batch_size, batch_size).
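
The shapes described here can be reproduced with a minimal PyTorch sketch. The tensors `a` and `b`, the sizes, and the way the logits are built are hypothetical stand-ins for illustration, not the repository's actual loss code.

```python
import torch

torch.manual_seed(0)
batch_size, dim = 4, 8  # hypothetical sizes

# Hypothetical embeddings standing in for the two views being contrasted.
a = torch.randn(batch_size, dim)
b = torch.randn(batch_size, dim)

# With keepdim=True the reduced axis is kept: shape (batch_size, 1).
pos_keep = torch.exp((a * b).sum(dim=1, keepdim=True))
# With keepdim=False (the default) it is dropped: shape (batch_size,).
pos_flat = torch.exp((a * b).sum(dim=1))

# A denominator of shape (batch_size,), as described in the comment.
neg = torch.exp(a @ b.T).sum(dim=1)

ratio_keep = pos_keep / neg  # broadcasts to (batch_size, batch_size)
ratio_flat = pos_flat / neg  # stays (batch_size,)

print(ratio_keep.shape, ratio_flat.shape)
```

In this sketch the diagonal of `ratio_keep` holds the same values as `ratio_flat`; the off-diagonal entries pair the numerator of one sample with the denominator of another.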

hyp1231 commented Feb 20, 2023
Member

Thanks for your attention and the great contribution! It is indeed a bug, and it seems to affect the pre-training procedure. The bug may add noise to the pre-training objectives. After merging this PR, I suppose pre-training could be more efficient and converge faster. The results could be slightly different compared to what we've reported.

I'll continue testing the performance after fixing this bug and will post an update later. Thanks again for the PR!

hyp1231 merged commit ed0bfbc into RUCAIBox:master on Feb 20, 2023

yinpu commented Feb 20, 2023
Contributor Author

Thank you for your response. However, this modification will not affect the results, as the final use of the mean function will lead to the same results as before, only reducing computation and memory consumption.

hyp1231 commented Feb 20, 2023
Member

> Thank you for your response. However, this modification will not affect the results, as the final use of the mean function will lead to the same results as before, only reducing computation and memory consumption.

Thanks! I hope so, but while testing I found that [[a1], [a2], [a3]] / [b1, b2, b3] seems to produce:

a1/b1   a1/b2   a1/b3
a2/b1   a2/b2   a2/b3
a3/b1   a3/b2   a3/b3

and only the diagonal elements of this matrix are what we need. I think the other elements could be noise.
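
The broadcast above can be written out in plain Python to see which entries are produced and how a final mean treats them. This is only a sketch: the numbers are made up, and the `-log(...)` step before the mean is an assumption about the form of the loss, which the thread does not show.

```python
from math import log, isclose

# Made-up values standing in for [[a1], [a2], [a3]] and [b1, b2, b3].
a = [1.0, 2.0, 4.0]
b = [8.0, 16.0, 2.0]

# Explicit version of the broadcast division: entry (i, j) is a[i] / b[j].
full = [[ai / bj for bj in b] for ai in a]
# The entries the loss is meant to use: a[i] / b[i].
diag = [full[i][i] for i in range(len(a))]

def mean(values):
    values = list(values)
    return sum(values) / len(values)

flat = [x for row in full for x in row]

# Plain means of the ratios: full matrix versus diagonal only.
mean_full = mean(flat)
mean_diag = mean(diag)

# ASSUMPTION: the loss applies -log to each ratio before averaging.
loss_full = mean(-log(x) for x in flat)
loss_diag = mean(-log(x) for x in diag)

print("plain means equal:", isclose(mean_full, mean_diag))
print("-log means equal: ", isclose(loss_full, loss_diag))
```

Under that assumption the two averaged losses coincide, because -log(a_i / b_j) = log(b_j) - log(a_i) and each term is averaged separately; the plain means of the ratios differ. Whether this applies to the actual code depends on how the loss is computed after the division.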
