Implement nvfp4 #546

Draft
zianglih wants to merge 24 commits into radixark:main from zianglih:nvfp4

Conversation


@zianglih zianglih commented Feb 2, 2026

@gemini-code-assist

Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the model quantization capabilities by integrating NVIDIA's NVFP4 and MXFP8 formats. It provides the necessary processing logic, conversion tools for Hugging Face models, and updates to the training and inference pipelines to leverage these new, more efficient data types, particularly for MoE architectures and Blackwell hardware. The changes aim to improve performance and memory footprint for large language models.

Highlights

  • NVFP4 Quantization Implementation: Introduced support for NVIDIA's NVFP4 (E2M1) quantization scheme, specifically targeting Mixture-of-Experts (MoE) expert GEMMs. This includes new processor logic for handling NVFP4 quantization, global and block scaling, and packing FP4 values.
  • MXFP8 Quantization Implementation: Added support for MXFP8 quantization, utilizing sglang's mxfp8_group_quantize for efficient quantization of model parameters, particularly within decoder layers and MLP experts.
  • Hugging Face Model Conversion Scripts: New utility scripts (convert_hf_to_nvfp4.py and convert_hf_to_mxfp8.py) have been added to convert existing Hugging Face safetensors checkpoints into the respective NVFP4 and MXFP8 quantized formats, including updating model configuration files.
  • Integration into Quantization Pipeline: The central quantize_params dispatcher now recognizes and routes to the new NVFP4 and MXFP8 quantization methods based on the quant_method or quant_algo configuration (a rough sketch follows this list).
  • Updated Training and Rollout Scripts: The run_qwen3_30b_a3b.py and run_qwen3_4b.py scripts have been updated to include options for enabling NVFP4 and MXFP8 during training and rollout, with checks for hardware compatibility (Blackwell GPUs) and mutual exclusivity with other FP8 modes.
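
As a rough orientation, the routing described in the highlight above could look like the sketch below. This is illustrative only: the handler names and the exact config keys are assumptions, not the actual code in this PR.

# Illustrative only: the real dispatcher lives in the megatron_to_hf processors package.
def quantize_params(args, converted_named_params, quant_config):
    method = (quant_config.get("quant_method") or quant_config.get("quant_algo") or "").lower()
    if method == "nvfp4":
        return quantize_params_nvfp4(args, converted_named_params)  # stand-in name for the NVFP4 processor
    if method == "mxfp8":
        return quantize_params_mxfp8(args, converted_named_params)  # stand-in name for the MXFP8 processor
    if method == "fp8":
        return quantize_params_fp8(args, converted_named_params)    # stand-in name for the existing FP8 path
    return converted_named_params  # no quantization requested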

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for NVFP4 and MXFP8 quantization, adding new quantization processors, conversion scripts, and updating run scripts. However, it introduces a critical command injection vulnerability in several scripts used for training and conversion. These scripts interpolate user-controlled input, specifically model_name, directly into shell command strings without proper sanitization, which could lead to arbitrary code execution. Beyond this security concern, there is significant code duplication of the NVFP4 quantization logic between the processor module and the conversion script, which should be refactored. Opportunities also exist to improve maintainability by refactoring repeated logic in the MXFP8 quantizer and the run scripts, and parts of the NVFP4 implementation could be optimized for better performance.

Comment on lines +122 to +229
    for converted_name, param in converted_named_params:
        base, role = _split_gated_pair_name(converted_name)
        if base is None or role is None:
            continue
        if _should_quantize_param(converted_name, param, group_size):
            gated_candidates.setdefault(base, {})[role] = param

    for base, roles in gated_candidates.items():
        if "gate" in roles and "up" in roles and _should_share_gated_pair_amax(args, base):
            gate_amax = roles["gate"].abs().max().to(torch.float32)
            up_amax = roles["up"].abs().max().to(torch.float32)
            shared_global_amax[base] = torch.max(gate_amax, up_amax)

    quantize_named_params = []
    for converted_name, param in converted_named_params:
        if not _should_quantize_param(converted_name, param, group_size):
            quantize_named_params.append((converted_name, param))
            continue
        base, _role = _split_gated_pair_name(converted_name)
        global_amax = shared_global_amax.get(base) if base else None
        qweight, block_scale, weight_scale_2 = quantize_nvfp4(param, global_amax=global_amax, group_size=group_size)
        quantize_named_params.append((converted_name, qweight))
        quantize_named_params.append((converted_name.replace(".weight", ".weight_scale"), block_scale))
        quantize_named_params.append((converted_name.replace(".weight", ".weight_scale_2"), weight_scale_2))
        quantize_named_params.append(
            (converted_name.replace(".weight", ".input_scale"), torch.ones_like(weight_scale_2, dtype=torch.float32))
        )

    return quantize_named_params


def _should_quantize_param(name, weight, group_size):
    if not name.endswith(".weight"):
        return False
    if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False
    if weight.dim() < 2:
        return False
    if weight.shape[-1] % group_size != 0:
        raise ValueError(f"Last dim {weight.shape[-1]} must be divisible by {group_size} for NVFP4 ({name}).")
    return True


def _split_gated_pair_name(name: str):
    for suffix, role in GATED_PAIR_SUFFIXES.items():
        if name.endswith(suffix):
            return name[: -len(suffix)], role
    return None, None


def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16


def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
    group_size: int = NVFP4_GROUP_SIZE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    NVFP4 1D quantization (tile shape = 1x16), adapted from
    TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference.

    Returns:
        qweight: uint8 packed fp4, shape (M, K // 2)
        block_scale: float8_e4m3fn, shape (M, K // group_size)
        global_scale: float32 scalar tensor
    """
    weight = weight.contiguous()
    m, n = weight.shape
    if n % group_size != 0:
        raise ValueError(f"NVFP4 requires K divisible by {group_size}, got {n}.")

    weight_f = weight.to(torch.float32)
    if global_amax is None:
        global_amax = torch.max(torch.abs(weight_f))
    else:
        global_amax = global_amax.to(device=weight.device, dtype=torch.float32)
    if global_amax.item() == 0.0:
        qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device)
        block_scale = torch.zeros(
            (m, n // group_size),
            dtype=torch.float8_e4m3fn,
            device=weight.device,
        )
        global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
        return qweight, block_scale, global_scale

critical

The functions cast_to_fp4x2, _quantize_nvfp4_1d, and quantize_nvfp4 are identical to those in the new tools/convert_hf_to_nvfp4.py script. This significant code duplication will make future maintenance difficult and error-prone, as any change would need to be manually synchronized between both files.

Please extract this shared quantization logic into a common utility module (e.g., within miles/utils/) and import it in both this file and the conversion script.
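
As a sketch of that refactor, under the assumption that a new shared module is introduced (the module path and re-exported names below are placeholders, not existing files):

# miles/utils/nvfp4_quant.py  (hypothetical shared module)
#     NVFP4_GROUP_SIZE, FP4_E2M1_MAX, FP8_E4M3_MAX
#     def cast_to_fp4x2(x): ...
#     def _quantize_nvfp4_1d(weight, global_amax=None, group_size=NVFP4_GROUP_SIZE): ...
#     def quantize_nvfp4(weight, global_amax=None, group_size=NVFP4_GROUP_SIZE): ...

# miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py
from miles.utils.nvfp4_quant import NVFP4_GROUP_SIZE, quantize_nvfp4

# tools/convert_hf_to_nvfp4.py
from miles.utils.nvfp4_quant import NVFP4_GROUP_SIZE, quantize_nvfp4

Both call sites would then pick up any fix to the packing or scaling logic from one place.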

Comment on lines +133 to +238
def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16


def _quantize_nvfp4_1d(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    NVFP4 1D quantization (tile shape = 1x16), adapted from
    TransformerEngine NVFP4QuantizerRef._quantize_blockwise_reference.

    Returns:
        qweight: uint8 packed fp4, shape (M, K // 2)
        block_scale: float8_e4m3fn, shape (M, K // 16)
        global_scale: float32 scalar tensor
    """
    weight = weight.contiguous()
    m, n = weight.shape
    if n % NVFP4_GROUP_SIZE != 0:
        raise ValueError(f"NVFP4 requires K divisible by {NVFP4_GROUP_SIZE}, got {n}.")

    weight_f = weight.to(torch.float32)
    if global_amax is None:
        global_amax = torch.max(torch.abs(weight_f))
    else:
        global_amax = global_amax.to(device=weight.device, dtype=torch.float32)
    if global_amax.item() == 0.0:
        qweight = torch.zeros((m, n // 2), dtype=torch.uint8, device=weight.device)
        block_scale = torch.zeros(
            (m, n // NVFP4_GROUP_SIZE),
            dtype=torch.float8_e4m3fn,
            device=weight.device,
        )
        global_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
        return qweight, block_scale, global_scale

    fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32)
    fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32)

    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax)
    global_encode_scale = torch.min(
        global_encode_scale,
        torch.tensor(torch.finfo(torch.float32).max, device=weight.device, dtype=torch.float32),
    )
    if global_encode_scale.item() == 0.0:
        global_encode_scale = torch.tensor(1.0, device=weight.device, dtype=torch.float32)
    global_decode_scale = torch.div(1.0, global_encode_scale)

    weight_blocks = weight_f.view(m, n // NVFP4_GROUP_SIZE, NVFP4_GROUP_SIZE)
    vec_max = torch.amax(torch.abs(weight_blocks), dim=-1, keepdim=True)
    decode_scale = torch.div(vec_max, fp4_max) * global_encode_scale
    decode_scale = torch.clamp(decode_scale, min=-fp8_max, max=fp8_max).to(torch.float8_e4m3fn)

    encode_scale = torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale)
    scaled = weight_blocks * encode_scale
    clipped = torch.clamp(scaled, -fp4_max, fp4_max).reshape(m, n)

    qweight = cast_to_fp4x2(clipped)
    block_scale = decode_scale.squeeze(-1)
    return qweight, block_scale, global_decode_scale


def quantize_nvfp4(
    weight: torch.Tensor,
    global_amax: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if weight.dim() == 2:
        return _quantize_nvfp4_1d(weight, global_amax=global_amax)
    if weight.dim() == 3:
        if global_amax is not None:
            raise ValueError("global_amax override is only supported for 2D weights.")
        qweights = []
        block_scales = []
        global_scales = []
        for idx in range(weight.shape[0]):
            qweight, block_scale, global_scale = _quantize_nvfp4_1d(weight[idx])
            qweights.append(qweight)
            block_scales.append(block_scale)
            global_scales.append(global_scale)
        return (
            torch.stack(qweights, dim=0),
            torch.stack(block_scales, dim=0),
            torch.stack(global_scales, dim=0),
        )
    raise ValueError(f"Unsupported weight rank {weight.dim()} for NVFP4 quantization.")

critical

The functions cast_to_fp4x2, _quantize_nvfp4_1d, and quantize_nvfp4 are identical to those in miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_nvfp4.py. This significant code duplication is problematic for maintenance.

Please extract this shared quantization logic into a common utility module to ensure consistency and simplify future updates.

Comment thread scripts/run_qwen3_4b.py
    ckpt_args = (
        f"--hf-checkpoint /root/models/{args.model_name}{'-FP8' if args.rollout_fp8 else ''} "
        f"--load {load_save_path} "
        f"--hf-checkpoint {hf_checkpoint}/ "

security-high

The hf_checkpoint variable, derived from the unsanitized args.model_name, is used in the construction of ckpt_args, which is later executed as a shell command. This allows for arbitrary command execution if a malicious model name is provided.

Comment on lines +55 to +64
            U.exec_command(
                f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
            )

    if use_nvfp4:
        nvfp4_path = f"/root/models/{args.model_name}-NVFP4"
        if not os.path.isdir(nvfp4_path):
            U.exec_command(
                f"python tools/convert_hf_to_nvfp4.py --model-dir /root/models/{args.model_name} --save-dir {nvfp4_path}"
            )

security-high

The script is vulnerable to command injection because it constructs shell commands using the unsanitized user-controlled input args.model_name. An attacker could provide a malicious model name containing shell metacharacters (e.g., ; rm -rf /) to execute arbitrary commands in the environment where this script is run. This input flows into U.exec_command on lines 56 and 63, and into the training command via hf_checkpoint on line 95.
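
One possible mitigation is sketched below. Whether U.exec_command can accept an argument list instead of a string is an assumption, so an argument-vector form via subprocess is shown first and a quoted command string second; the helper names are illustrative.

import shlex
import subprocess


def convert_to_mxfp8(model_name: str, save_dir: str) -> None:
    # Option 1: pass an argument vector so the model name is never parsed by a shell.
    subprocess.run(
        [
            "python", "tools/convert_hf_to_mxfp8.py",
            "--model-dir", f"/root/models/{model_name}",
            "--save-dir", save_dir,
        ],
        check=True,
    )


def convert_to_mxfp8_cmdline(model_name: str, save_dir: str) -> str:
    # Option 2: if a single command string is required (e.g. for U.exec_command),
    # quote every interpolated value before it reaches the shell.
    return (
        "python tools/convert_hf_to_mxfp8.py "
        f"--model-dir {shlex.quote('/root/models/' + model_name)} "
        f"--save-dir {shlex.quote(save_dir)}"
    )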

Comment thread scripts/run_qwen3_4b.py
Comment on lines +60 to +62
    U.exec_command(
        f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
    )

security-high

Similar to the issue in scripts/run_qwen3_30b_a3b.py, this script is vulnerable to command injection. The user-controlled args.model_name is used to construct shell commands executed via U.exec_command on line 61 and interpolated into the training command via hf_checkpoint on line 89.

Comment on lines +37 to +56

    # experts
    expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
    match = re.match(expert_pattern, rest)
    if match:
        rest, expert_idx = match.groups()
        if rest in [
            "linear_fc1",
            "linear_fc2",
        ]:
            return _quantize_moe_params(args, converted_named_params, group_size)

    # shared expert
    shared_expert_pattern = r"mlp.shared_experts\.(.+)"
    match = re.match(shared_expert_pattern, rest)
    if match:
        rest = match.groups()[0]
        if rest in [
            "linear_fc1.weight",
            "linear_fc2.weight",

medium

The logic to identify which parameters to quantize is spread across multiple conditional blocks based on regex matching. This makes the control flow complex and hard to follow.

To improve readability and maintainability, consider consolidating the decision logic. You could define the patterns for quantizable layers and then have a single block that performs the quantization if a match is found. This would also make it easier to add new quantizable layers in the future.
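
For instance, a single table of patterns could drive the dispatch. This is a sketch under the assumption that the surrounding function already dispatches on rest as in the quoted hunk; the shared-expert handler name is a placeholder, not an existing function.

import re

# Illustrative consolidation: one table mapping quantizable-layer patterns to handlers.
# The handler callables are assumed to be defined earlier in the same module.
_QUANTIZABLE_LAYER_PATTERNS = [
    (re.compile(r"mlp\.experts\.(linear_fc1|linear_fc2)\.weight\d+"), _quantize_moe_params),
    (re.compile(r"mlp\.shared_experts\.(linear_fc1|linear_fc2)\.weight"), _quantize_shared_expert_params),  # placeholder
]


def _dispatch_quantization(args, rest, converted_named_params, group_size):
    for pattern, handler in _QUANTIZABLE_LAYER_PATTERNS:
        if pattern.fullmatch(rest):
            return handler(args, converted_named_params, group_size)
    return converted_named_params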

Comment on lines +84 to +93
    use_blackwell_fp8 = args.hardware in ("GB200", "GB300") and (args.rollout_fp8 or args.train_fp8)
    use_nvfp4 = args.rollout_nvfp4
    if use_nvfp4:
        hf_checkpoint = f"/root/models/{args.model_name}-NVFP4"
    elif use_blackwell_fp8:
        hf_checkpoint = f"/root/models/{args.model_name}-MXFP8"
    elif args.rollout_fp8:
        hf_checkpoint = f"/root/models/{args.model_name}-FP8"
    else:
        hf_checkpoint = f"/root/models/{args.model_name}"

medium

The logic to determine the Hugging Face checkpoint path based on quantization flags is verbose. The helper variables use_blackwell_fp8 and use_nvfp4 are also recalculated in both the prepare and execute functions.

To improve clarity and reduce repetition, you could move the calculation of these flags into the __post_init__ method of ScriptArgs and add a helper method to the class to resolve the checkpoint path.
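
A sketch of that suggestion follows. The field names mirror the flags used in the script, but the exact ScriptArgs definition here is an assumption.

from dataclasses import dataclass


@dataclass
class ScriptArgs:
    # Only the fields relevant to checkpoint resolution are shown; the real
    # dataclass in the run scripts has more fields.
    model_name: str
    hardware: str
    rollout_fp8: bool
    train_fp8: bool
    rollout_nvfp4: bool

    def __post_init__(self):
        self.use_blackwell_fp8 = self.hardware in ("GB200", "GB300") and (
            self.rollout_fp8 or self.train_fp8
        )
        self.use_nvfp4 = self.rollout_nvfp4

    def hf_checkpoint_path(self) -> str:
        base = f"/root/models/{self.model_name}"
        if self.use_nvfp4:
            return f"{base}-NVFP4"
        if self.use_blackwell_fp8:
            return f"{base}-MXFP8"
        if self.rollout_fp8:
            return f"{base}-FP8"
        return base

Both prepare and execute could then call args.hf_checkpoint_path() instead of rebuilding the conditional.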

Comment thread scripts/run_qwen3_4b.py
Comment on lines +80 to +86
    use_blackwell_fp8 = args.hardware in ("GB200", "GB300") and (args.rollout_fp8 or args.train_fp8)
    if use_blackwell_fp8:
        hf_checkpoint = f"/root/models/{args.model_name}-MXFP8"
    elif args.rollout_fp8:
        hf_checkpoint = f"/root/models/{args.model_name}-FP8"
    else:
        hf_checkpoint = f"/root/models/{args.model_name}"

medium

The logic to determine the Hugging Face checkpoint path based on quantization flags is repeated from other scripts and could be simplified.

Consider adding a helper method to the ScriptArgs class to encapsulate this logic. This would improve code reuse and make the execute function cleaner.

Comment on lines +133 to +154
def cast_to_fp4x2(x: torch.Tensor) -> torch.Tensor:
    """Quantize a tensor to FP4 E2M1 and pack two values per byte."""
    result = torch.zeros_like(x, dtype=torch.uint8)
    result[(x >= 0.0) & (x <= 0.25)] = 0
    result[(x > 0.25) & (x < 0.75)] = 1
    result[(x >= 0.75) & (x <= 1.25)] = 2
    result[(x > 1.25) & (x < 1.75)] = 3
    result[(x >= 1.75) & (x <= 2.5)] = 4
    result[(x > 2.5) & (x < 3.5)] = 5
    result[(x >= 3.5) & (x <= 5.0)] = 6
    result[x > 5.0] = 7

    result[(x >= -0.25) & (x < -0.0)] = 8
    result[(x < -0.25) & (x > -0.75)] = 9
    result[(x <= -0.75) & (x >= -1.25)] = 10
    result[(x < -1.25) & (x > -1.75)] = 11
    result[(x <= -1.75) & (x >= -2.5)] = 12
    result[(x < -2.5) & (x > -3.5)] = 13
    result[(x <= -3.5) & (x >= -5.0)] = 14
    result[x < -5.0] = 15

    return result[:, ::2] + result[:, 1::2] * 16

medium

The cast_to_fp4x2 function is implemented with a series of boolean masks, which can be inefficient for large tensors. For better performance, consider using a more optimized PyTorch operation like torch.bucketize to quantize the tensor in a single pass.
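
A sketch of the bucketize approach is below. It is untested, behavior at exact bin edges may differ slightly from the masked version, and it assumes the same 2D layout with an even number of columns.

import torch

# FP4 E2M1 decision thresholds (midpoints between representable magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6).
_FP4_BIN_EDGES = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])


def cast_to_fp4x2_bucketize(x: torch.Tensor) -> torch.Tensor:
    """Quantize to FP4 E2M1 codes in one pass and pack two codes per byte."""
    edges = _FP4_BIN_EDGES.to(device=x.device, dtype=torch.float32)
    codes = torch.bucketize(x.abs().to(torch.float32), edges).to(torch.uint8)  # magnitude codes 0..7
    codes = torch.where(x < 0, codes + 8, codes)  # sign selects the upper code range 8..15
    # Even columns go to the low nibble, odd columns to the high nibble.
    return codes[:, ::2] + codes[:, 1::2] * 16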

    fp4_max = torch.tensor(FP4_E2M1_MAX, device=weight.device, dtype=torch.float32)
    fp8_max = torch.tensor(FP8_E4M3_MAX, device=weight.device, dtype=torch.float32)

    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax)

medium

The _quantize_nvfp4_1d function uses torch.div. Using the / operator for division is more idiomatic and readable in PyTorch and produces the same result here.

Suggested change
-    global_encode_scale = torch.div(fp8_max * fp4_max, global_amax)
+    global_encode_scale = fp8_max * fp4_max / global_amax

ziang-and requested a review from guapisolo as a code owner on February 12, 2026 at 04:47

zianglih commented Feb 12, 2026

My current NVFP4 dev workflow:

All commands:


# SGLANG
cd /sgl-workspace/
rm -rf sglang
git clone -b miles-0.5.8 https://github.com/zianglih/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python"
pip uninstall -y flashinfer-jit-cache flashinfer-cubin
pip3 install --upgrade flashinfer-python flashinfer-cubin --force-reinstall --no-deps

# TE
export MAX_JOBS=200
pip uninstall -y transformer_engine transformer_engine_cu12 transformer_engine_torch
pip install pybind11
cd /root
rm -rf TransformerEngine
git clone --recursive -b keep-bwd-v2.10 https://github.com/zianglih/TransformerEngine.git
cd TransformerEngine
export NVTE_FRAMEWORK=pytorch         # Optionally set framework
pip3 install --no-build-isolation .   # Build and install

# Miles
cd /root/
rm -rf miles
git clone -b nvfp4 https://github.com/zianglih/miles.git
cd /root/miles
pip install -e . --no-deps

# Megatron
pip install pybind11
pip uninstall -y megatron-core
cd /root
rm -rf Megatron-LM
git clone --recursive -b miles-0.5.8-te-precision-config https://github.com/zianglih/Megatron-LM.git
cd Megatron-LM
pip install --no-deps --no-build-isolation .[mlm,dev]


# nvfp4: keep last 8
python scripts/run_qwen3_30b_a3b.py --no-colocate --actor-num-gpus-per-node 4 --rollout-num-gpus 4 --no-enable-eval --hardware GB200 --num-gpus-per-node 8 --rollout-nvfp4 --nvfp4-keep-last-n 8 --extra-args "--use-rollout-routing-replay --use-miles-router --sglang-moe-runner-backend cutlass --sglang-enable-nan-detection  --sglang-kv-cache-dtype bf16 "
# nvfp4: keep first 2, last 8
python scripts/run_qwen3_30b_a3b.py --no-colocate --actor-num-gpus-per-node 4 --rollout-num-gpus 4 --no-enable-eval --hardware GB200 --num-gpus-per-node 8 --rollout-nvfp4 --nvfp4-keep-first-n 2 --nvfp4-keep-last-n 8 --extra-args "--use-rollout-routing-replay --use-miles-router --sglang-moe-runner-backend cutlass --sglang-enable-nan-detection  --sglang-kv-cache-dtype bf16 "
# nvfp4 QAT
python scripts/run_qwen3_30b_a3b.py --no-colocate --actor-num-gpus-per-node 4 --rollout-num-gpus 4 --no-enable-eval --hardware GB200 --num-gpus-per-node 8 --rollout-nvfp4 --train-nvfp4 --extra-args "--use-rollout-routing-replay --use-miles-router --sglang-moe-runner-backend cutlass --sglang-enable-nan-detection  --sglang-kv-cache-dtype bf16 "
