Introduce Knowledge Distillation Base#432
Merged
Merged
Conversation
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> Set default `chunk_size` to `1024` Signed-off-by: Austin Liu <austin362667@gmail.com> Rebase Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
4 tasks
Collaborator
Sorry for the misleading question and late response. Passing What I questioned in the last comment was about the |
Contributor
Author
|
Thanks @Tcc0403 and @hongpeng-guo . I think we can merge this PR first? To unblock other distill loss impl. I have some follow-ups to iterate on in my mind:
@ByronHsu WDYT? |
austin362667
commented
Dec 9, 2024
3 tasks
Tcc0403
pushed a commit
that referenced
this pull request
Jan 30, 2025
## Summary > [!CAUTION] > This PR depends on #417. Do not merge until #417 (later #432) is merged. This is a pure torch compiled, chunked fused linear JSD Loss, aiming for knowledge distillation. #### Jensen-Shannon Divergence Loss This PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in a distillation setting (teacher & student). This component can be replaced with other losses (e.g., KL divergence) as `distillation_loss_fn`. JSD is defined as the average of the KL divergences between each distribution and the mean distribution: ```math \text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q) ``` Here, `P`and `Q` are the two probability distributions, and `M` is their average. ## Testing Done Below figures are benchmark results with different `chunk_size`, which also significantly affects performance. #### Hint: User can tune their `chunk_size` as suggested by the liger [paper](https://arxiv.org/pdf/2306.13649) for the moment: ```math 2^{\lceil \log_2 \lceil \frac{BT}{V/H} \rceil \rceil} ``` #### Memory 1. `chunk_size` = 1  2. `chunk_size` = 1024  #### Speed (Elapsed Time) 1. `chunk_size` = 1  2. `chunk_size` = 1024  - Hardware Type: NVIDIA H100 80GB HBM3 (SXM5) - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Recreate #417 from the main repo.
Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split of #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.
Code Changes
Refactor
betainto two weights:weight_hard_lossandweight_soft_loss, as coefficients betweenhard_lossandsoft_loss. @Tcc0403 also pointed out that we could usetorch.lerpif applicable.Pass
teacher_logitsandstudent_logitsdirectly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing thestudent_log_probsvalue calculated duringce_lossin distillation base.get_batch_logpsintest/utils.py.Modify
chunkingdimensions fromBtoB * T. Thanks to @hongpeng-guo's great advice.Normalize the
distillation_lossusing(full_target != ignore_index).sum().TODO
Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is significantly slower compared to the naive approach. Thanks to @Tcc0403 's clarification.
The issue arises because we are not properly configuring the
chunk_sizefor theB * Tdimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.In contrast, this problem does not occur with the preference loss, as chunking is performed on the
Bdimension. This produces fewer than 10 chunks, which is efficient and works as expected.In conclusion, I set
chunk_sizeto1024works pretty well in new benchmark results as shown in Add JSD Loss for Distillation #425Introduce Knowledge Distillation Base #417 (comment)
Knowledge Distillation
Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.
In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let
z_tandz_srepresent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperatureT. When ground truth labelsyare available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.The combined loss function is defined as:
Here, we directly pass in
logitsrather thanlogpbs. @Tcc0403Shared
DistillationBaseTo support various distillation learning objectives, this PR aims to add a
LigerFusedLinearDistillationBasewhich is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this.Testing Done
I'll post JSD tests and benchmarks results in next PR: #425
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence