
ROCm: enable trillion-parameter MoE models with INT4-FP8 single node#4152

Merged
zhyncs merged 2 commits intosgl-project:mainfrom
HaiShaw:int4-fp8
Mar 6, 2025

Conversation


@HaiShaw HaiShaw commented Mar 6, 2025

INT4 MoE weights, FP8 compute

credits: @shengnxu, @coderfeli, @carlushuang, @kkHuang-amd , @leishaoSC, @valarLip, @HaiShaw

Motivation

  • Enable models with more than 1.2 trillion parameters on a single node of 8x MI300/MI308.
  • Speed up decoding with INT4 weights, which reduce memory-bandwidth pressure.
  • Use the latest FP8 tensor cores for computation (available on MI300 and MI308).

The model used can be accessed at https://huggingface.co/amd/grok-1-W4A8KV8 (please apply for access at https://huggingface.co/amd). You can also contact us on the SGLang Slack for a temporary token.

grok-1-W4A8KV8/config.json:

{
  "_name_or_path": "/group/amdneuralopt/huggingface/pretrained_models/grok-1-sglang-tp1",
  "architectures": [
    "Grok1ModelForCausalLM"
  ],
  "attn_output_multiplier": 0.08838834764831845,
  "auto_map": {
    "AutoConfig": "configuration_grok1.Grok1Config",
    "AutoModel": "modeling_grok1.Grok1Model",
    "AutoModelForCausalLM": "modeling_grok1.Grok1ModelForCausalLM"
  },
  "bos_token_id": 1,
  "embedding_multiplier_scale": 78.38367176906169,
  "eos_token_id": 2,
  "hidden_size": 6144,
  "intermediate_size": 32768,
  "max_attn_value": 30.0,
  "max_position_embeddings": 8192,
  "model_type": "grok-1",
  "num_attention_heads": 48,
  "num_local_experts": 8,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 64,
  "num_key_value_heads": 8,
  "output_multiplier_scale": 0.5773502691896257,
  "output_router_logits": false,
  "pad_token_id": 0,
  "quantization_config": {
    "activation_scheme": "static",
    "export": {
      "kv_cache_group": [
        "*k_proj",
        "*v_proj"
      ],
      "min_kv_scale": 1.0,
      "pack_method": "reorder",
      "weight_format": "real_quantized",
      "weight_merge_groups": null
    },
    "ignored_layers": [
      "model.layers.0.block_sparse_moe.gate",
      ... ... ... ...
      "model.layers.63.block_sparse_moe.gate",
      "lm_head"
    ],
    "kv_cache_scheme": "static",
    "quant_method": "fp8",
    "int4_experts": {
      "bits": 4,
      "sym": true,
      "group": "column"
    }
  },
  "rms_norm_eps": 1e-05,
  "router_aux_loss_coef": 0.001,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.47.1",
  "use_cache": true,
  "vocab_size": 131072
}
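
The quantization_config above quantizes only the expert weights to INT4 ("bits": 4, symmetric, per-column groups); attention stays FP8 with static activation and KV-cache scales, and the router gates and lm_head are left unquantized. Below is a minimal sketch of how such weights could be dequantized, assuming two INT4 values packed per int8 and one scale per output channel; the tensor names qweight/scales are illustrative, not the checkpoint's real keys, and this is not SGLang's actual loader:

import torch

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack two signed 4-bit values from each int8 element (low nibble first)."""
    low = packed & 0x0F
    low = torch.where(low >= 8, low - 16, low)   # sign-extend low nibble to -8..7
    high = packed >> 4                           # arithmetic shift sign-extends
    return torch.stack([low, high], dim=-1).flatten(-2)

def dequant_w4(qweight: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    w = unpack_int4(qweight).to(torch.float32)   # [out, in]
    return w * scales.view(-1, 1)                # symmetric INT4: no zero-point

# 4 output channels, 8 input channels stored as 4 packed int8 columns.
qweight = torch.randint(-128, 128, (4, 4), dtype=torch.int8)
scales = torch.rand(4) * 0.01
w = dequant_w4(qweight, scales)
print(w.shape)  # torch.Size([4, 8])
# At runtime, activations are quantized to FP8 and the matmul runs on the
# MI300/MI308 FP8 tensor cores; only the weight dequant is shown here.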

Modifications

Accuracy is preserved, with less than a 1% margin on GSM8K scores.
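
For reference, one way to reproduce this check is SGLang's bundled few-shot GSM8K script (a sketch; exact flags can vary by version, and the server must already be serving the model, e.g. via python -m sglang.launch_server with the same model, tokenizer, TP, and quantization flags used in the benchmarks below):

python -m sglang.test.few_shot_gsm8k --num-questions 200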

  • Grok-1 FP8 performance (one measured run)
/sgl-workspace/sglang# python -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 512 --model /data/lmzheng-grok-1/ --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --quantization fp8
Benchmark ...
Prefill. latency: 1.70331 s, throughput:  19237.80 token/s
Decode.  latency: 0.01748 s, throughput:   1830.72 token/s
Decode.  latency: 0.01791 s, throughput:   1786.83 token/s
Decode.  latency: 0.01777 s, throughput:   1800.57 token/s
Decode.  latency: 0.01796 s, throughput:   1781.26 token/s
Decode.  latency: 0.01792 s, throughput:   1785.74 token/s
Decode.  median latency: 0.02416 s, median throughput:   1324.33 token/s
Total. latency: 13.594 s, throughput:   3615.73 token/s
  • Grok-1 INT4-FP8 quantized model performance (one measured run)
# CK_MOE=1 USE_INT4_WEIGHT=1 python -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 512 --model /data/grok-1-W4A8KV8 --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --quantization fp8 --trust-remote-code
Benchmark ...
Prefill. latency: 2.21035 s, throughput:  14824.80 token/s
Decode.  latency: 0.02072 s, throughput:   1544.74 token/s
Decode.  latency: 0.02016 s, throughput:   1587.36 token/s
Decode.  latency: 0.02007 s, throughput:   1594.26 token/s
Decode.  latency: 0.02013 s, throughput:   1589.62 token/s
Decode.  latency: 0.02016 s, throughput:   1587.66 token/s
Decode.  median latency: 0.02068 s, median throughput:   1547.29 token/s
Total. latency: 12.734 s, throughput:   3859.76 token/s
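
From the two runs above: INT4-FP8 improves median decode throughput by roughly 17% (1547.29 vs 1324.33 token/s) and median decode latency from 0.02416 s to 0.02068 s, at the cost of about 23% slower prefill (14824.80 vs 19237.80 token/s); end-to-end throughput is about 6.7% higher (3859.76 vs 3615.73 token/s).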

INT4-FP8 model architecture

[figure: INT4-FP8 model architecture diagram]

Conclusion:

  • INT4-FP8 enables serving a much larger model on one server.
  • The INT4-FP8 model yields better median decode throughput and latency, which serves the purpose.


@HaiShaw HaiShaw changed the title ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8 (INT4 … ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8 Mar 6, 2025
@HaiShaw HaiShaw changed the title ROCm/AITER: enable trillion-parameter MoE models with INT4-FP8 ROCm: enable trillion-parameter MoE models with INT4-FP8 single node Mar 6, 2025
@zhyncs zhyncs merged commit 13bc39c into sgl-project:main Mar 6, 2025

Comment thread on python/sglang/srt/layers/quantization/fp8.py:

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
    # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
    if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
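
For context, a brief sketch of why this adjustment exists (illustrative, not the actual SGLang code; the helper name is made up). AMD's float8_e4m3fnuz uses exponent bias 8 where e4m3fn uses bias 7, so the same bit pattern encodes half the e4m3fn value; scales exported against e4m3fn therefore get doubled to keep dequantized values unchanged:

import torch

def adjust_input_scale_for_fnuz(input_scale: torch.Tensor) -> torch.Tensor:
    # fnuz halves the value of every fn bit pattern (bias 8 vs 7), so
    # doubling the scale keeps x_fp8 * input_scale numerically identical.
    return input_scale * 2.0

print(adjust_input_scale_for_fnuz(torch.tensor([0.05])))  # tensor([0.1000])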
Contributor:
USE_INT4_WEIGHT -> SGLANG_ROCM_USE_INT4_WEIGHTS

        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        if get_bool_env_var("USE_INT4_WEIGHT"):
Contributor:
move this part out into a separate function.
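
A sketch of that suggestion (the helper name is hypothetical; elided bodies stand for the existing code):

def process_weights_after_loading(self, layer: Module) -> None:
    if get_bool_env_var("USE_INT4_WEIGHT"):
        # The whole INT4-FP8 post-processing path moves into one helper.
        self._process_int4_weights_after_loading(layer)
        return
    # ... existing FP8-only path ...

def _process_int4_weights_after_loading(self, layer: Module) -> None:
    # INT4-specific repacking and scale handling, kept out of the FP8 path.
    ...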

HaiShaw commented Mar 6, 2025

@merrymercy let me handle your request soon.

@merrymercy mentioned this pull request Mar 13, 2025
@Alcanderian (Collaborator):
Using INT4-FP8 inside the pure FP8 module is ambiguous; we should refactor it into a stand-alone w4a8 module!
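
A sketch of what that refactor might look like (class and base names are hypothetical, loosely modeled on the existing fused-MoE method pattern; elided bodies stand for the logic currently inlined in fp8.py):

class FusedMoEMethodBase:  # stand-in for SGLang's real base class
    pass

class W4A8Fp8MoEMethod(FusedMoEMethodBase):
    """Stand-alone method: INT4 expert weights, FP8 activations/compute."""

    def create_weights(self, layer, *args, **kwargs):
        # Allocate packed INT4 expert weights plus per-column scales.
        ...

    def process_weights_after_loading(self, layer):
        # INT4 repacking and e4m3fnuz scale adjustment live here.
        ...

    def apply(self, layer, x, *args, **kwargs):
        # Dispatch to the CK/AITER INT4-FP8 fused-MoE kernel.
        ...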
