Add FusedLinearJSD #300

Merged
qingquansong merged 17 commits into linkedin:main from Tcc0403:fl-jsd on Oct 11, 2024

Conversation

@Tcc0403 Tcc0403 commented Oct 8, 2024

Collaborator

Summary

Similar to the fused linear CE.

It fuses the final linear layer with the JSD loss and handles both the forward and backward passes without materializing the large logits tensor. Since JSD is the last layer, we can compute the gradients during the forward pass.
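As a rough illustration of the idea above, the sketch below walks over the rows in chunks so that only a (chunk, V) slice of student logits exists at a time, and it computes the student gradients inside the same loop. It is plain PyTorch with autograd, not the Triton kernels from this PR. The names jsd_sum and chunked_linear_jsd are hypothetical, and it assumes the standard symmetric JSD (equal weights) with a sum over rows divided by BT.

```python
import math

import torch
import torch.nn.functional as F


def jsd_sum(student_logits, teacher_logits):
    """Symmetric JSD between teacher (P) and student (Q), summed over rows."""
    log_q = F.log_softmax(student_logits, dim=-1)
    log_p = F.log_softmax(teacher_logits, dim=-1)
    # log M for the mixture M = (P + Q) / 2, computed in log space.
    log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
    kl_pm = (log_p.exp() * (log_p - log_m)).sum()
    kl_qm = (log_q.exp() * (log_q - log_m)).sum()
    return 0.5 * (kl_pm + kl_qm)


def chunked_linear_jsd(student_x, student_w, teacher_x, teacher_w, chunk_size=2):
    """Return (loss, grad_x, grad_w) for the student, one row chunk at a time."""
    BT = student_x.shape[0]
    grad_x = torch.zeros_like(student_x)
    grad_w = torch.zeros_like(student_w)
    loss = torch.zeros((), dtype=student_x.dtype, device=student_x.device)
    w = student_w.detach().requires_grad_(True)
    with torch.enable_grad():
        for start in range(0, BT, chunk_size):
            end = min(start + chunk_size, BT)
            x = student_x[start:end].detach().requires_grad_(True)
            with torch.no_grad():  # the teacher needs no gradient
                teacher_logits = teacher_x[start:end] @ teacher_w.T
            # Only a (chunk, V) slice of student logits is alive here.
            chunk_loss = jsd_sum(x @ w.T, teacher_logits) / BT
            gx, gw = torch.autograd.grad(chunk_loss, (x, w))
            grad_x[start:end] = gx  # rows are independent, so write in place
            grad_w += gw            # the weight gradient accumulates over chunks
            loss += chunk_loss.detach()
    return loss, grad_x, grad_w


# Small made-up shapes; the benchmark in this PR uses H=4096, V=128256.
torch.manual_seed(0)
BT, H, V = 5, 8, 16
student_x = torch.randn(BT, H, dtype=torch.float64)
student_w = torch.randn(V, H, dtype=torch.float64)
teacher_x = torch.randn(BT, H, dtype=torch.float64)
teacher_w = torch.randn(V, H, dtype=torch.float64)
loss, grad_x, grad_w = chunked_linear_jsd(student_x, student_w, teacher_x, teacher_w)
```

Because the loss is the last node in the graph, the gradients returned here are already the final ones up to the scalar upstream gradient, which is why the backward pass only needs to rescale them.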

Testing Done

Hidden size: 4096, Vocab size: 128256
Benchmark plots (images not preserved): fused_linear_jsd_memory, fused_linear_jsd_speed

  • Hardware Type: NVIDIA H100 80GB HBM3 (SXM5)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 Tcc0403 marked this pull request as ready for review October 9, 2024 01:51
out=grad_weight,
)

loss = torch.sum(loss_1d) / BT

Collaborator Author

I just noticed that torch.sum might overflow when BT is large. It can be fixed in either of two ways:

  1. Divide before summing, i.e. torch.sum(loss_1d / BT).
  2. Move the division inside the JSD kernel.

I prefer the second solution, but I'm not sure whether it is OK to modify another kernel in this PR.
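To see why the order of operations matters, here is a small sketch. It assumes, purely for illustration, that the per-row losses are held in float16 (largest finite value 65504). The values are made up, and this loss_1d is a stand-in rather than the tensor from the PR.

```python
import torch

BT = 1000
# Hypothetical per-row losses whose total (100000) exceeds the float16 range.
loss_1d = torch.full((BT,), 100.0, dtype=torch.float16)

# Sum first: the float16 result of the sum overflows before the division.
sum_then_divide = torch.sum(loss_1d) / BT

# Divide first: every term is about 0.1, so the total stays in range.
divide_then_sum = torch.sum(loss_1d / BT)

print(sum_then_divide, divide_then_sum)
```

Moving the division into the kernel has the same effect as the second form, since each row is scaled before anything is accumulated.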

Collaborator

Can the existing division in JSD handle this? 🤔 I saw you changed n_row for each chunk to BT.

@Tcc0403 Tcc0403 Oct 9, 2024

Collaborator Author

Currently, the n_row parameter in JSD is used only for the gradient calculation.
It could be modified to work the way cross_entropy does, which computes the loss according to the expected reduction method:

loss = loss / n_non_ignore

I saw you changed n_row for each chunk to BT.

Simply passing BT gives the correct result without the extra alpha scaling that flce needs:

alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
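The two normalization strategies mentioned here can be compared with plain arithmetic. This is a sketch with made-up per-row losses, and it assumes every row counts (no ignored rows), so the non-ignored counts are just the chunk sizes.

```python
import math

# Hypothetical per-row losses for BT = 6 rows, split into chunks of 4 and 2.
row_losses = [0.2, 0.4, 0.1, 0.3, 0.5, 0.6]
chunks = [row_losses[:4], row_losses[4:]]
BT = len(row_losses)

# Reference: mean over all rows at once.
expected = sum(row_losses) / BT

# Passing BT: each chunk divides by the global row count, then results add up.
via_global_count = sum(sum(chunk) / BT for chunk in chunks)

# flce-style: each chunk takes its own mean, then is rescaled by alpha.
via_alpha = 0.0
for chunk in chunks:
    n_non_ignore = len(chunk)
    total_n_non_ignore = BT
    alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
    via_alpha += alpha * (sum(chunk) / n_non_ignore)

print(expected, via_global_count, via_alpha)
```

Both routes give the same mean; dividing by BT directly just skips the per-chunk rescaling step.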

@Tcc0403 Tcc0403 Oct 9, 2024

Collaborator Author

Can the existing division in JSD handle this?

It has a potential overflow issue too. That's one of the reasons I think moving the division into the kernel is better.

Collaborator

Gotcha. I'm good with moving it inside the JSD kernel. For ce/flce, let's keep them for now. Thanks!

qingquansong previously approved these changes Oct 9, 2024

@qingquansong qingquansong left a comment

Collaborator

LGTM! We can add the ignore index later similar to here so it can be easily used for the SFT context. Great work!


@triton.jit
def element_mul_kernel(

Collaborator

Maybe we can delete the copy in the original ce file and share this one.

@Tcc0403 Tcc0403 Oct 9, 2024

Collaborator Author

yeah, I can add the new ce file

@qingquansong qingquansong merged commit ff6650b into linkedin:main Oct 11, 2024
@ByronHsu ByronHsu mentioned this pull request Oct 31, 2024
@Tcc0403 Tcc0403 deleted the fl-jsd branch December 1, 2024 03:12