Introduce Distillation with a Chunked, Fused Linear JS-divergence Loss#408
Introduce Distillation with a Chunked, Fused Linear JS-divergence Loss#408austin362667 wants to merge 17 commits into
Conversation
Signed-off-by: Austin Liu <austin362667@gmail.com> Add Testing Naive Distillation Base Signed-off-by: Austin Liu <austin362667@gmail.com> Add Chunked JSD Tests and Benchmarks Signed-off-by: Austin Liu <austin362667@gmail.com> Fix call Signed-off-by: Austin Liu <austin362667@gmail.com> Fix Test Usage Signed-off-by: Austin Liu <austin362667@gmail.com> Remove beta Signed-off-by: Austin Liu <austin362667@gmail.com> Fix test params Signed-off-by: Austin Liu <austin362667@gmail.com> Fix call Signed-off-by: Austin Liu <austin362667@gmail.com> Fix ignore_index Signed-off-by: Austin Liu <austin362667@gmail.com> Fix weights dimension Signed-off-by: Austin Liu <austin362667@gmail.com> Fix assign dimension Signed-off-by: Austin Liu <austin362667@gmail.com> Fix assign dimension Signed-off-by: Austin Liu <austin362667@gmail.com> Fix teacher bias Signed-off-by: Austin Liu <austin362667@gmail.com> Reshape input Signed-off-by: Austin Liu <austin362667@gmail.com> Fix mean Signed-off-by: Austin Liu <austin362667@gmail.com> Remove alpha Signed-off-by: Austin Liu <austin362667@gmail.com> Fix t Signed-off-by: Austin Liu <austin362667@gmail.com> Fix t Signed-off-by: Austin Liu <austin362667@gmail.com> Fix t scaling Signed-off-by: Austin Liu <austin362667@gmail.com> Remove teacher tests Signed-off-by: Austin Liu <austin362667@gmail.com> Fix t scaling Signed-off-by: Austin Liu <austin362667@gmail.com> Fix beta Signed-off-by: Austin Liu <austin362667@gmail.com> Fix beta Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com> WIP Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com> Clean up Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
Signed-off-by: Austin Liu <austin362667@gmail.com>
| if valid_mask.any(): | ||
| student_average_log_prob[valid_mask] = ( | ||
| student_per_token_logps * loss_mask | ||
| ).sum(-1)[valid_mask] / loss_mask_sum[valid_mask] |
There was a problem hiding this comment.
I'm not quite understand what loss_mask_sum and valid_mask do. Is it just a way to avoid ZeroDivisionError?
| loss_fn=None, | ||
| chunk_size=1, | ||
| ignore_index=-100, | ||
| beta=0.5, |
There was a problem hiding this comment.
Perhaps we need another variable name for the weight between soft and hard loss, since some loss functions have 'beta' parameter, such as generalized jsd we've implemented in #278.
Since lambda is a reserved keyword, maybe weight_hard_loss and weight_soft_loss?
If sum of both weights is 1, you can just pick one of them and also consider torch.lerp() for combining 2 losses
| labels.view(-1), | ||
| ) | ||
|
|
||
| student_logps = self.get_batch_logps( |
There was a problem hiding this comment.
Do we need to calculate the probability per token for knowledge distillation? I might be wrong but don't we just pass teacher_logits and student_logits directly to divergence loss function, such as kldiv (normally with reduction="batchmean") or jsd?
There was a problem hiding this comment.
@Tcc0403 Thank you for the review!!
Actually, you're right—I’m aware of that. I was just trying to align the interface with the preference-based design and reuse the value of student_log_probs calculated during ce_loss in DistillBase. However, if it's not necessary to maintain the same interface, I prefer your suggestion.
As shown in the distillation calculation function in this PR, it essentially undoes the operations. This redundant computation could be avoided by directly passing the raw logits to the divergence function, instead of first converting them to log probabilities and then reversing them back to the original values.
| label_chunk = torch.where(loss_mask, target_chunk, 0) | ||
|
|
||
| student_average_log_prob = torch.zeros_like(loss_mask, dtype=torch.float) | ||
| student_per_token_logps = student_log_probs_chunk.gather( |
If you're referring to current LigerFusedLinearJSD, there're some benchmark in #300. When comparing forward pass only, fljsd kernel is supposed to be slower since it does gradient calculations in forward pass as well, and it isn't purely written in triton so it might also suffer from kernel launching overhead. But it's true that it doesn't perform well in low BT scenario. |
| student_logps (torch.Tensor): Avg log probabilities of student inputs. Shape: (batch_size, hidden_size,). | ||
| teacher_logps (torch.Tensor): Avg log probabilities of teacher inputs. Shape: (batch_size, hidden_size,). |
There was a problem hiding this comment.
I think for the general distillation loss, the student and teacher logps should be per-token instead of being averaged in the sequence length dimension. I.e., both tensors should be of shape (bathc_size, sequence_size, vocab_size) or (flattended_batch_sequence_size, vocab_size).
| distillation_loss = distillation_loss_fn( | ||
| student_logps, teacher_logps, temperature | ||
| ) | ||
| distillation_loss = distillation_loss / (full_target.shape[0]) |
There was a problem hiding this comment.
After we made the distillation loss per_token, we may normalize the distillation_loss with full_target != ignore_index).sum similar to the ce_loss.
hongpeng-guo
left a comment
There was a problem hiding this comment.
Thanks a lot for drafting the distillation base class, left some comments on the fused_linear_distillation.py, mainly discussing the loss should be computed per token or averaged from the sequence level first.
Anther major question that I am having is on the chunking dimensions. Current implementation of this PR is just chunking from the batch_size dimension, which is similar to the implementation of fused_linear_preference.py. However, I think it would be better if we can chunk from the flattened dim[0] of (B*T, vocab_size), which is also the way of chunking described in the paper for CE_loss.
For preference_base class, I think the chunking only happens on the batch_size dimension because the sequence dimension is reduced when calculating the average logps (link). . But for distillation, we may prefer to follow the patten of CE loss to chunk on the joint dimension of B*T, so that this kernel can work for very long sequence/ context scenario. Happy to help refine this base class @austin362667
cc @shivam15s what do you think on this?
|
@austin362667 nit: A side note is to split this PR into two stacked PRs: first for the distillation base class and second for the JSDloss based from it. We can prioritize to polish and merge the first PR so that other distillation losses can be based on it and it's non-blocking 😄 |
|
@hongpeng-guo Thanks for review~
That makes perfect sense to me; I'll proceed with this approach.
Absolutely! I'll split this into two separate PRs. |
|
Thanks all nice comments! @Tcc0403 and @hongpeng-guo |
## Summary Made #417 from the main repo. Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR is the s first split from #408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment. #### Code Changes 1. Refactor `beta` into two weights: `weight_hard_loss` and `weight_soft_loss`, as coefficients between `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` if applicable. 2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids redundant computations of converting logits to log probabilities and then reverting them to raw logits. However note that we are not reusing the `student_log_probs` value calculated during `ce_loss` in distillation base. 1. Remove the unnecessary `get_batch_logps` in `test/utils.py`. 3. Modify `chunking` dimensions from `B` to `B * T`. Thanks to @hongpeng-guo's great advice. 1. Fix the loss calculation to use per-token values instead of averaging across the sequence length dimension. 4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`. #### TODO 1. [X] Although a slightly slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** compared to the naive approach. Thanks to @Tcc0403 's clarification. The issue arises because we are not properly configuring the `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference loss, as chunking is performed on the `B` dimension. This produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, I set `chunk_size` to `1024` works pretty well in new benchmark results as shown in #425 2. [ ] #417 (comment) #### Knowledge Distillation Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student. In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` represent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth. The combined loss function is defined as: ```math \mathcal{L}_{\text{knowledge distillation}} = \mathcal{w}_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + \mathcal{w}_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}), ``` Here, we directly pass in `logits` rather than `logpbs`. @Tcc0403 #### Shared `DistillationBase` To support various distillation learning objectives, this PR aims to add a `LigerFusedLinearDistillationBase` which is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this. ## Testing Done I'll post JSD tests and benchmarks results in next PR: #425 - Hardware Type: L40S - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu <austin362667@gmail.com> Co-authored-by: shivam15s <shivam15800@gmail.com>
Summary
Knowledge Distillation
Knowledge Distillation (KD; Hinton et al. 2015, Gou et al. 2020) is a straightforward way to build a smaller, cheaper model (“student model”) to speed up inference by transferring skills from a pre-trained expensive model (“teacher model”) into the student.
In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let
z_tandz_srepresent the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperatureT. When ground truth labelsyare available, this approach can be combined with a supervised learning objective, such as cross-entropy, to compare the student’s outputs with the ground truth.The combined loss function is defined as:
Here,
lambdais a hyperparameter that balances the distillation loss and the supervised objective.Shared
DistillationBaseTo support various distillation learning objectives, this PR aims to add a
LigerFusedLinearDistillationBasewhich is basically same as propose by @hongpeng-guo within this discussion #371 (comment). Thank you @hongpeng-guo for thinking through this.Jensen-Shannon Divergence Loss
In addition to adding the base class, this PR implements Jensen-Shannon Divergence (JSD) loss as the soft learning objective in the distillation setting. This component can be replaced with other losses (e.g., KL divergence) as
distillation_loss_fn.JSD is defined as the average of the KL divergences between each distribution and the mean distribution:
Here,
PandQare the two probability distributions, andMis their average.TODO
Testing Done
Yes.
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence