Introduce Knowledge Distillation Base#417
Conversation
5257d26 to
3a9f125
Compare
| hard_loss, | ||
| ) = forward_output | ||
|
|
||
| soft_loss = self.distillation_loss(student_logits, teacher_logits) |
There was a problem hiding this comment.
the method use logprobs : def distillation_loss(self, student_logps, teacher_logps): but you use logits here.
I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.
There was a problem hiding this comment.
the method use logprobs : def distillation_loss(self, student_logps, teacher_logps): but you use logits here.
@winglian Nice catch! Thank you so much.
I'd actually like to see both a logit and logprob implementation since it's easy to get logprobs offline from vllm and that is a faster way to generate the dataset.
Sure, I think it's doable. And, I'm not quite sure I fully understand the need for logprobs implementation. Mind elaborate more on the vLLM use case?
There was a problem hiding this comment.
So rather than having to have the teacher model loaded during training, depending on the workload type, it can be faster and more compute efficient to pre-compute the logins/logprobs offline beforehand. However, vllm and sglang only provide the logprobs, and that's not easily back-calculated to logits.
There was a problem hiding this comment.
I see. That makes a lot sense to me. Thank you!
There was a problem hiding this comment.
@winglian curious if vllm/sglang support temperature scaled logprobs. This would be needed to enable https://github.com/huggingface/trl/blob/9c5388b69e0842f76edc46a2ff9d0b51e1db4337/trl/trainer/gkd_trainer.py#L174
There was a problem hiding this comment.
I believe we can address this ask in a subsequent PR
@ByronHsu what do you think?
The issue arises because we are not properly configuring the In contrast, this problem does not occur with the preference loss, as chunking is performed on the In conclusion, I set |
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> Set default `chunk_size` to `1024` Signed-off-by: Austin Liu <austin362667@gmail.com> Rebase Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
4ada908 to
e381569
Compare
hongpeng-guo
left a comment
There was a problem hiding this comment.
@austin362667 FWIW, to run the Modal GPU CIs, this PR needs to be made from the main repo, i.e., linkedin/Liger-Kernel, instead of the forked repo.
A similar example is: I closed #399 and moved to #400 to enable the CI pipeline.
shivam15s
left a comment
There was a problem hiding this comment.
can you create another PR in linkedin? Some tests fail for me locally so I'd like to confirm before merging
|
@shivam15s Certainly, right here #432 Thanks a lot |
|
Move discussion to #432 |
## Summary Made #417 from the main repo. Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the s first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment. #### Code Changes 1. Refactor `beta` into two weights: `weight_hard_loss` and `weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` if applicable. 2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing the `student_log_probs` value calculated during `ce_loss` in distillation base. 1. Remove the unnecessary `get_batch_logps` in `test/utils.py`. 3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to @hongpeng-guo's great advice. 1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension. 4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`. #### TODO 1. [X] Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** compared to the naive approach. Thanks to @Tcc0403 's clarification. The issue arises because we are not properly configuring the `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference loss, as chunking is performed on the `B` dimension. This produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, I set `chunk_size` to `1024` works pretty well in new benchmark results as shown in #425 2. [ ] #417 (comment) #### Knowledge Distillation Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student. In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth. The combined loss function is defined as: ```math \mathcal{L}_{\text{knowledge distillation}} = \mathcal{w}_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + \mathcal{w}_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}), ``` Here, we directly pass in `logits` rather than `logpbs`. @Tcc0403 #### Shared `DistillationBase` To support various distillation learning objectives, this PR aims to add a `LigerFusedLinearDistillationBase` which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this. ## Testing Done I'll post JSD tests and benchmarks results in next PR: #425 - Hardware Type: L40S - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com> Co-authored-by: shivam15s <shivam15800@gmail.com>
## Summary > [!CAUTION] > This PR depends on #417. Do not merge until #417 (later #432) is merged. This is a pure torch compiled, chunked fused linear JSD Loss, aiming for knowledge distillation. #### Jensen-Shannon Divergence Loss This PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in a distillation setting (teacher & student). This component can be replaced with other losses (e.g., KL divergence) as `distillation_loss_fn`. JSD is defined as the average of the KL divergences between each distribution and the mean distribution: ```math \text{JSD}(P || Q) = \frac{1}{2} \text{KL}(P || M) + \frac{1}{2} \text{KL}(Q || M), \quad \text{where } M = \frac{1}{2}(P + Q) ``` Here, `P`and `Q` are the two probability distributions, and `M` is their average. ## Testing Done Below figures are benchmark results with different `chunk_size`, which also significantly affects performance. #### Hint: User can tune their `chunk_size` as suggested by the liger [paper](https://arxiv.org/pdf/2306.13649) for the moment: ```math 2^{\lceil \log_2 \lceil \frac{BT}{V/H} \rceil \rceil} ``` #### Memory 1. `chunk_size` = 1  2. `chunk_size` = 1024  #### Speed (Elapsed Time) 1. `chunk_size` = 1  2. `chunk_size` = 1024  - Hardware Type: NVIDIA H100 80GB HBM3 (SXM5) - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com>
Summary
Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.
Code Changes
Refactor
betainto two weights:weight_hard_lossandweight_soft_loss, as coefficients betweenhard_lossandsoft_loss. @Tcc0403 also pointed out that we could usetorch.lerpif applicable.Pass
teacher_logitsandstudent_logitsdirectly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing thestudent_log_probsvalue calculated duringce_lossin distillation base.get_batch_logpsintest/utils.py.Modify
chunkingdimensions fromBtoB * T. Thanks to @hongpeng-guo's great advice.Normalize the
distillation_lossusing(full_target != ignore_index).sum().TODO
Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is significantly slower compared to the naive approach. Thanks to @Tcc0403 's clarification.
The issue arises because we are not properly configuring the
chunk_sizefor theB * Tdimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks.In contrast, this problem does not occur with the preference loss, as chunking is performed on the
Bdimension. This produces fewer than 10 chunks, which is efficient and works as expected.In conclusion, I set
chunk_sizeto1024works pretty well in new benchmark results as shown in Add JSD Loss for Distillation #425Knowledge Distillation
Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.
In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let
z_tandz_srepresent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperatureT. When ground truth labelsyare available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.The combined loss function is defined as:
Here, we directly pass in
logitsrather thanlogpbs. @Tcc0403Shared
DistillationBaseTo support various distillation learning objectives, this PR aims to add a
LigerFusedLinearDistillationBasewhich is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this.Testing Done
I'll post JSD tests and benchmarks results in next PR: #425
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence