vLLM FP8 quantized support for SFT/GRPO#3414
Merged
Merged
Conversation
| del W_deq | ||
| return grad_X, None, None | ||
|
|
||
| @torch.compile |
Member
There was a problem hiding this comment.
Can you check if torch.compile(fullgraph = True, dynamic = True) works better.
Also try using:
from unsloth_zoo.temporary_patches.common import torch_compile_options, torch_compile
@torch_compile
def ...
Collaborator
Author
There was a problem hiding this comment.
I noticed no performance difference between the three when trying out Qwen3-8B between any of the 3
| if weight_fake_quantizer is not None: | ||
| W = weight_fake_quantizer(W) | ||
|
|
||
| W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None) |
Member
There was a problem hiding this comment.
Tbh best to make an if elif to make it faster
Collaborator
Author
There was a problem hiding this comment.
My only worry is someone mistakenly changing when I add if..else cuz if tensor would fail when tensor exists.
one needs to explicitly do if tensor is not None or something like that
I thought this is a safer way to let people continue this/avoid that
But can change if you feel its better that way
danielhanchen
requested changes
Oct 15, 2025
| if weight_fake_quantizer is not None: | ||
| W = weight_fake_quantizer(W) | ||
|
|
||
| W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None) |
abiswas-realadvice
pushed a commit
to abiswas-realadvice/unsloth
that referenced
this pull request
May 14, 2026
* Prefer loading model from pretrained instead of config * Fixup FP8 forward pass and inference * [WIP] Fix lora forwards * Infer block size from weight shapes * reconstruct weights from fp8 quants for lora matmul * Return weight transpose and fix dtype * Refactor FP8 operations * Fix naming :) * Saner compile * do not depend on transformers * [WIP] fix training * Update comment * fixup training * use dequant kernel from deepseek * Differentiate between fp8 and fbgemmfp8 * fixup differentiation b/w fp8 and fbgemm_fp8 * make inputs contiguous if required * Improve dequant * More robust handling * Fixup backward pass for fbgemm_fp8 * refactor and use bf16 for dequant * Use torch fp8 block matmul * Disable torch block matmul for now * safer import and cosmetics * more cosmectics * add torchao operations * Spaceeeeeee
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Depends on unslothai/unsloth-zoo#313