
Adds FunAudioChat multimodal audio model support (#2) #33058

Merged
DarkLight1337 merged 13 commits into vllm-project:main from nemoramo:main
Jan 28, 2026
Conversation

@nemoramo
Contributor

@nemoramo nemoramo commented Jan 26, 2026

Supports fully functional FunAudioChat ST2T-only inference mode in vLLM.

Purpose

Add vLLM support for FunAudioChatForConditionalGeneration (multimodal audio → text).

This PR:

  • Implements the FunAudioChat model in vllm/model_executor/models/funaudiochat.py (multimodal embeddings + audio towers + LM integration).
  • Registers the model/config in the vLLM model registry and Transformers config registry so it can be loaded via --model.
  • Adds/updates the multimodal audio processor path, supporting audio inputs as:
    • np.ndarray (waveform), and
    • (np.ndarray, sampling_rate) tuples to enable automatic resampling when needed.
  • Improves audio media handling (e.g., WAV byte fallback when optional deps are missing; validation to avoid decoding unsupported WAV formats).
  • Adds an offline inference example for quick functional verification.
  • Updates CI model-registry coverage by adding an example entry in tests/models/registry.py (placeholder if no public HF repo is available).

Test Plan

Unit / CI-style tests

Run:

  • python -m pytest -q tests/models/test_registry.py -k FunAudioChat
  • python -m pytest -q tests/multimodal
  • python -m pytest -q tests/model_executor -k funaudiochat (if applicable in this repo)

Lint / formatting

Run one of the following (depending on repo setup):

  • pre-commit run --all-files
  • or python -m ruff check . && python -m ruff format --check .

Notes / constraints:

  • Long-audio path may require FlashAttention-2 (flash_attn) for performance and/or memory (runtime error provides actionable instructions if missing).
  • Resampling is triggered only when input provides (audio, sr) and sr differs from the target SR; it is skipped when SR already matches.
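The resampling rule in the notes above can be sketched concretely. Below is a minimal sketch assuming a hypothetical `normalize_audio_input` helper and a 16 kHz target rate; the real processor would use a proper resampler such as `librosa.resample` rather than the naive interpolation shown here:

```python
import numpy as np

TARGET_SR = 16000  # hypothetical target sampling rate


def normalize_audio_input(audio, target_sr=TARGET_SR):
    """Accept either a raw waveform or an (waveform, sampling_rate) tuple.

    Resampling happens only when a sampling rate is provided and it differs
    from the target; a bare ndarray is assumed to already be at the target
    rate, matching the behavior described in the PR notes.
    """
    if isinstance(audio, tuple):
        waveform, sr = audio
    else:
        waveform, sr = audio, target_sr
    if sr != target_sr:
        # Naive linear-interpolation resample for illustration only.
        duration = waveform.shape[0] / sr
        new_len = int(round(duration * target_sr))
        old_t = np.linspace(0.0, duration, num=waveform.shape[0], endpoint=False)
        new_t = np.linspace(0.0, duration, num=new_len, endpoint=False)
        waveform = np.interp(new_t, old_t, waveform)
    return waveform.astype(np.float32), target_sr
```

Passing `(waveform, 8000)` doubles the sample count to reach 16 kHz, while a bare `np.ndarray` passes through untouched.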

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify
Contributor

mergify bot commented Jan 26, 2026

Documentation preview: https://vllm--33058.org.readthedocs.build/en/33058/

@mergify mergify bot added the documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), and new-model (Requests to new models) labels on Jan 26, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for the FunAudioChat multimodal audio model, including the core model implementation, configuration, tests, and example scripts. The implementation appears to be a comprehensive port of the original model. The associated changes to improve audio handling robustness, such as adding a fallback for WAV file loading, are beneficial. I've identified one potential high-severity issue in the model implementation that could lead to a division-by-zero error under specific configuration, which I've detailed in a comment.

if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")

log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
Contributor

Severity: high

This line can cause a division-by-zero error if channels is 2, as the expression channels // 2 - 1 would evaluate to 0. While the default d_model (which is passed as channels) is large, a custom model configuration could set it to 2, leading to a crash. Please add a check at the beginning of the __init__ method to prevent this. For example:

if channels < 4:
    raise ValueError(f'SinusoidsPositionEmbedding needs channels to be >= 4 (got {channels})')
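Putting the two guards together, a sketch of the embedding's `__init__` might look like this; the buffer construction is illustrative, following the common Whisper-style sinusoid table, and only the guard logic comes from the review comments above:

```python
import numpy as np
import torch
import torch.nn as nn


class SinusoidsPositionEmbedding(nn.Module):
    """Sketch of the sinusoidal position table with the suggested guard."""

    def __init__(self, length: int, channels: int, max_timescale: float = 10000.0):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        if channels < 4:
            # Guard against channels // 2 - 1 == 0 in the division below.
            raise ValueError(
                f"SinusoidsPositionEmbedding needs channels to be >= 4 (got {channels})"
            )
        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(
            -log_timescale_increment * torch.arange(channels // 2)
        )
        # (length, channels // 2) grid of position * inverse timescale.
        scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
        self.register_buffer(
            "positional_embedding",
            torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1),
            persistent=False,
        )
```

With the guard in place, `channels=2` now fails loudly at construction time instead of dividing by zero.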

@mergify
Contributor

mergify bot commented Jan 26, 2026

Hi @nemoramo, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@cursor cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 5 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

if self.continuous_features_mode == "add":
hidden_states[feature_exist_mask] += continuous_audio_hidden_states
else:
hidden_states[feature_exist_mask] = continuous_audio_hidden_states

Shape mismatch when feature_exist_mask is partial

Medium Severity

When feature_exist_mask is not all True, hidden_states[feature_exist_mask] produces a tensor with fewer batch items than continuous_audio_hidden_states. The addition or assignment attempts to combine tensors of shape (num_valid, seq_len, hidden) with (batch_size, seq_len, hidden), causing a shape mismatch. The continuous_audio_hidden_states tensor also needs to be indexed with feature_exist_mask.
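A standalone sketch of the suggested fix, indexing both sides with the same mask so the shapes line up even for a partial batch (function name and `mode` argument are illustrative):

```python
import torch


def merge_continuous_features(
    hidden_states: torch.Tensor,
    continuous: torch.Tensor,
    mask: torch.Tensor,
    mode: str = "add",
) -> torch.Tensor:
    """Merge per-item continuous audio features into hidden_states.

    Both tensors are indexed with the same boolean mask, so the left and
    right sides always have matching (num_valid, seq_len, hidden) shapes.
    """
    if mode == "add":
        hidden_states[mask] = hidden_states[mask] + continuous[mask]
    else:
        hidden_states[mask] = continuous[mask]
    return hidden_states
```

With a partial mask such as `[True, False, True]`, only the masked batch rows are updated and the unmasked row is left untouched.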


]
)
self.ln_post = nn.LayerNorm(embed_dim)
self.avg_pooler = nn.AvgPool1d(2, stride=2)

Unused avg_pooler module is dead code

Low Severity

self.avg_pooler is defined as nn.AvgPool1d(2, stride=2) but never used. The code instead uses nn.functional.avg_pool1d directly at line 389. Since nn.AvgPool1d has no learnable parameters, this module serves no purpose and can be removed.


eos_token_id: int | None = None,
group_size: int = 5,
enable_audio_invert_tower: bool = True,
pad_token_id: int | None = None,

Config defaults cause TypeError when creating model

Medium Severity

The config declares codebook_size: int | None = None and pad_token_id: int | None = None with None defaults. However, FunAudioChatDiscreteEncoder.__init__ calls int(config.pad_token_id) and int(config.codebook_size) which raises TypeError when these are None. When FunAudioChatConfig is created without an explicit audio_config, a default config with these None values is used.
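One way to surface this earlier is to fail fast with a clear message rather than letting `int(None)` raise a bare TypeError deep inside the encoder; a sketch with a hypothetical `require_int` helper:

```python
def require_int(value, field_name: str) -> int:
    """Coerce a config field to int, raising a descriptive error on None.

    Hypothetical helper: the encoder __init__ would call
    require_int(config.pad_token_id, "pad_token_id") instead of
    int(config.pad_token_id).
    """
    if value is None:
        raise ValueError(
            f"FunAudioChat audio_config.{field_name} must be set, got None"
        )
    return int(value)
```

This turns a cryptic `TypeError: int() argument must be ...` into an actionable message that names the missing config field.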

Additional Locations (1)


embed_dim = int(config.d_model)
self.num_mel_bins = int(config.num_mel_bins)
self.max_source_positions = int(config.max_source_positions)
self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0

Computed embed_scale variable is never used

Low Severity

self.embed_scale is computed based on config.scale_embedding but is never used in the forward method. If scale_embedding=True is configured, the expected scaling behavior (multiplying embeddings by sqrt(d_model)) will not occur, causing a mismatch with the intended HuggingFace model behavior.
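For reference, actually honoring `scale_embedding` would mean multiplying the embeddings by sqrt(d_model) in the forward pass; a minimal sketch with a hypothetical free function:

```python
import torch


def embed_audio_features(
    inputs_embeds: torch.Tensor, d_model: int, scale_embedding: bool
) -> torch.Tensor:
    """Apply the sqrt(d_model) scaling that scale_embedding requests.

    Sketch of the missing step: the reviewed code computes embed_scale in
    __init__ but never applies it in forward.
    """
    embed_scale = d_model**0.5 if scale_embedding else 1.0
    return inputs_embeds * embed_scale
```

With `d_model=4` the embeddings are doubled when `scale_embedding=True` and passed through unchanged otherwise.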



if attn_impl != "flash_attention_2":
# Upgrade the audio tower attention impl for this run.
self.config._attn_implementation = "flash_attention_2"

Config mutation affects shared state across requests

Medium Severity

When processing long audio (speech_maxlen >= 7500), the code permanently mutates self.config._attn_implementation to "flash_attention_2". This state change persists after the forward pass completes, affecting all subsequent inference requests even for short audio that wouldn't require FlashAttention-2.
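A common way to avoid this kind of leak is to scope the override with a context manager so the shared config is restored after the forward pass; a sketch (helper name is illustrative):

```python
from contextlib import contextmanager


@contextmanager
def temporary_attn_impl(config, impl: str):
    """Temporarily override _attn_implementation for one forward pass.

    Restores the previous value on exit, so the override never leaks into
    subsequent requests that share the same config object.
    """
    old = getattr(config, "_attn_implementation", None)
    config._attn_implementation = impl
    try:
        yield config
    finally:
        # Note: if the attribute did not exist before, this leaves it set
        # to None rather than deleting it; fine for a sketch.
        config._attn_implementation = old
```

The long-audio path would then wrap its forward call in `with temporary_attn_impl(self.config, "flash_attention_2"):` instead of mutating the config permanently.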


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e0fa71ebfc

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +800 to +804
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen3ForCausalLM"],

P2: Respect text_config model_type when selecting LM arch

The FunAudioChat config explicitly defaults text_config to Qwen2 for backward compatibility (vllm/transformers_utils/configs/funaudiochat.py:97-109), but the model initialization hard-codes architectures=["Qwen3ForCausalLM"]. That forces Qwen3 even when the checkpoint/config is Qwen2, so older FunAudioChat checkpoints (or configs without an embedded text_config) will fail to load due to mismatched architecture/weights. Consider deriving architectures from config.text_config.model_type or config.text_config.architectures instead of forcing Qwen3.
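A sketch of deriving the architecture from the text config instead of hard-coding it; the model_type mapping shown is illustrative, not vLLM's actual registry lookup:

```python
from types import SimpleNamespace


def resolve_lm_architectures(text_config) -> list[str]:
    """Prefer explicit architectures from the checkpoint's text_config,
    falling back to a model_type-based mapping (mapping is illustrative)."""
    archs = getattr(text_config, "architectures", None)
    if archs:
        return list(archs)
    by_type = {
        "qwen2": ["Qwen2ForCausalLM"],
        "qwen3": ["Qwen3ForCausalLM"],
    }
    model_type = getattr(text_config, "model_type", "qwen3")
    return by_type.get(model_type, ["Qwen3ForCausalLM"])
```

The result would replace the hard-coded `architectures=["Qwen3ForCausalLM"]` argument, so a Qwen2-based text_config loads the matching LM class.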


Comment on lines +127 to +130
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
Member

@DarkLight1337 DarkLight1337 Jan 26, 2026

You should use merge the QKV projections into QKVParallelLinear and use RowParallelLinear for out_proj. During weight loading you can use stacked_params_mapping to merge the weights, look at other models for examples.
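For illustration, the core idea behind `stacked_params_mapping` is that separate q/k/v checkpoint tensors are copied into slices of one fused parameter; a simplified sketch (real vLLM weight loading goes through each parameter's `weight_loader` with a shard id and also handles tensor-parallel sharding):

```python
import torch

# (fused param name, checkpoint param name, shard index) triples, mirroring
# the stacked_params_mapping convention used by vLLM models.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", 0),
    ("qkv_proj", "k_proj", 1),
    ("qkv_proj", "v_proj", 2),
]


def load_fused_qkv(fused: torch.Tensor, shards: dict[str, torch.Tensor]) -> torch.Tensor:
    """Copy separate q/k/v checkpoint weights into one fused QKV matrix."""
    embed_dim = fused.shape[0] // 3
    for _, ckpt_name, idx in stacked_params_mapping:
        fused[idx * embed_dim : (idx + 1) * embed_dim].copy_(shards[ckpt_name])
    return fused
```

Each checkpoint tensor lands in its own row-slice of the fused weight, which is what lets a `QKVParallelLinear` layer load weights saved as three separate projections.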

Contributor Author

You should use merge the QKV projections into QKVParallelLinear and use RowParallelLinear for out_proj. During weight loading you can use stacked_params_mapping to merge the weights, look at other models for examples.

Now added QKVParallelLinear and RowParallelLinear.

@mergify
Contributor

mergify bot commented Jan 26, 2026

Hi @nemoramo, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Comment on lines +141 to +143
with torch.no_grad():
if self.qkv_proj.bias is not None:
self.qkv_proj.bias.zero_()
Member

This parameter initialize seems to be training specific, so it can be removed

Contributor Author

This parameter initialize seems to be training specific, so it can be removed

done

value_states = value_states.transpose(0, 1).unsqueeze(0)
attn_impl = getattr(self.config, "_attn_implementation", "eager") or "eager"

attention_interface = _eager_attention_forward
Member

@DarkLight1337 DarkLight1337 Jan 26, 2026

Use vLLM's attention modules instead (refer to other models)

Member

By the way, have you tried loading the model automatically using --model-impl transformers without needing this PR?

Contributor Author

By the way, have you tried loading the model automatically using --model-impl transformers without needing this PR?

I believe that funaudiochat is not implemented in huggingface transformers yet. This model is fully implemented in https://github.com/FunAudioLLM/Fun-Audio-Chat.

Contributor Author

Use vLLM's attention modules instead (refer to other models)

Done, now using mmencoderattn.

nemoramo and others added 4 commits January 26, 2026 18:34
supports full-functional FunAudioChat ST2T only inference mode in VLLM

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Refactors several scripts to enhance readability by reformatting
long lines, improving variable names, and standardizing code style.
Makes logic easier to follow without changing functionality.

Signed-off-by: mayufeng <mayufeng@example.com>
Replaces custom audio attention with shared encoder attention backend,
enabling support for FlashAttention and aligning with vLLM's multimodal
infrastructure. Updates weight loading to handle fused QKV weights and
biases robustly. Ensures distributed group initialization for test scripts
and cleans up process groups on exit to prevent warnings.

Improves memory efficiency and correctness for long audio sequences,
raising explicit errors if FlashAttention is not available.

Relates to compatibility with vLLM multimodal models and enhanced
resource management for offline inference.

Signed-off-by: mayufeng <mayufeng@example.com>
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
Member

Can you move these examples into the consolidated example files (like audio_language.py)? We want to avoid having a separate example file for every model

Contributor Author

Sure, sorry, these were for my debugging and attention alignment.

Contributor Author

Can you move these examples into the consolidated example files (like audio_language.py)? We want to avoid having a separate example file for every model

Removed the other examples and consolidated the funaudiochat examples into audio_language.py.

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Comment thread vllm/_custom_ops.py Outdated
Comment thread vllm/transformers_utils/configs/funaudiochat.py
Eliminates an unnecessary conditional guard before registering a fake custom operation, streamlining the code and ensuring consistent registration behavior regardless of op availability.

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Introduces a note clarifying the need for a workaround due to missing
native support and auto_map in public checkpoints. Ensures compatibility
until official support or proper mapping is available in upstream
libraries.

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Comment thread tests/model_executor/test_funaudiochat_init.py Outdated
Comment thread tests/models/registry.py
trust_remote_code=True,
),
"FunAudioChatForConditionalGeneration": _HfExamplesInfo(
"funaudiochat", is_available_online=False
Member

Can you add an estimated min_transformers_version?

Contributor Author

Can you add an estimated min_transformers_version?

Currently not possible; I will send the PR to transformers, but I am not sure what the version will be.

Member

OK, no problem.

Comment thread vllm/multimodal/media/audio.py Outdated
Deletes tests targeting initialization logic and fallback behaviors
for the FunAudioChat model. Likely reflects deprecation, refactoring,
or migration of test coverage elsewhere.

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Simplifies audio loading logic by eliminating the custom WAV parsing
and requiring librosa for all audio decoding. Reduces code complexity
and ensures consistent behavior for all supported audio formats.

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) January 28, 2026 02:56
@DarkLight1337
Member

Thanks for your patience!

@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jan 28, 2026
@nemoramo
Contributor Author

[2026-01-28T03:25:35Z] =========================== short test summary info ============================ -- [2026-01-28T03:25:35Z] ERROR tokenizers_/test_detokenize.py::test_decode_streaming[False-True-True-mistralai/Pixtral-12B-2409-False-\u1015\u102f\u1036\u1015\u103c\u1004\u103a\u101c\u1031\u1038\u1015\u103c\u1031\u102c\u1015\u103c\u1015\u102b\u103a] - FileNotFoundError: Could not connect to the Hugging Face Hub and no local files were found for the repo ID mistralai/Pixtral-12B-2409 and revision main. Please check your internet connection and try again.

Does this issue matter? It seems like a connection problem. @DarkLight1337

@DarkLight1337
Member

Retrying

@DarkLight1337 DarkLight1337 merged commit 36d450e into vllm-project:main Jan 28, 2026
56 checks passed
apd10 pushed a commit to apd10/vllm that referenced this pull request Jan 31, 2026
…lm-project#33058)

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Signed-off-by: mayufeng <mayufeng@example.com>
Co-authored-by: mayufeng <mayufeng@example.com>
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…lm-project#33058)

Signed-off-by: ramos <49182011+nemoramo@users.noreply.github.com>
Signed-off-by: mayufeng <mayufeng@example.com>
Co-authored-by: mayufeng <mayufeng@example.com>
Signed-off-by: Pai <416932041@qq.com>

Labels

documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), ready (ONLY add when PR is ready to merge/full CI is needed)
