Parakeet nemotron encoder #23568

Merged
mickqian merged 11 commits into sgl-project:main from yhyang201:parakeet-nemotron-encoder
Apr 25, 2026

Conversation

@yhyang201
Collaborator

Motivation

Modifications

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@yhyang201
Collaborator Author

/tag-and-rerun-ci

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces audio processing and dynamic resolution support for the Nemotron-VL model, including the integration of the Parakeet audio encoder and a utility for extracting audio from video. Key enhancements include temporal video compression via tubelet grouping and ragged packing for variable-sized images. Feedback identifies a critical bug in the extract_feature method where a missing projection layer causes a dimension mismatch, and a style violation regarding an inline import that should be moved to the top of the file.
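The temporal compression mentioned here groups consecutive frame embeddings into tubelets so that T frames yield far fewer temporal tokens. A minimal sketch of the idea, with hypothetical names (group_tubelets, tubelet_size) and mean pooling chosen purely for illustration, not taken from this PR:

    import torch

    def group_tubelets(frame_embeds: torch.Tensor, tubelet_size: int = 4) -> torch.Tensor:
        """Pool every `tubelet_size` consecutive frames into one temporal token.

        frame_embeds: (T, num_patches, dim), with T assumed to be padded to a
        multiple of tubelet_size. Names and pooling choice are illustrative only.
        """
        t, num_patches, dim = frame_embeds.shape
        grouped = frame_embeds.view(t // tubelet_size, tubelet_size, num_patches, dim)
        return grouped.mean(dim=1)  # (T // tubelet_size, num_patches, dim)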

Comment on lines 206 to 225
      def extract_feature(self, pixel_values):
          # Process images in a micro-batch of at most 128 frames per call
          # This is done on purpose to ensure peak GPU ram usage of huge batch
          # (namely for really long videos with EVS ON) won't cause any problems
          # as we don't support chunked prefill for video media
          micro_batch_size = 128
          n = pixel_values.shape[0]
          patch_size = self.config.patch_size
          h_patches = pixel_values.shape[-2] // patch_size
          w_patches = pixel_values.shape[-1] // patch_size
          vit_embeds_list = []
          for i in range(0, n, micro_batch_size):
    -         vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
    -         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
    -         h = w = int(vit_embeds.shape[1] ** 0.5)
    -         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
    +         chunk = pixel_values[i : i + micro_batch_size]
    +         batch_size = chunk.shape[0]
    +         vit_embeds = self.vision_model(chunk)
    +         vit_embeds = vit_embeds.to(dtype=self.model_dtype)
    +         vit_embeds = vit_embeds.reshape(batch_size, h_patches, w_patches, -1)
              vit_embeds = self.pixel_shuffle(
                  vit_embeds, scale_factor=self.downsample_ratio
              )
              vit_embeds = vit_embeds.view(-1, self.rmsnorm_hidden_size)
    -         vit_embeds = self.mlp1(vit_embeds)
    -         vit_embeds = vit_embeds.view(n, -1, self.rmsnorm_hidden_size)
    +         vit_embeds = vit_embeds.view(batch_size, -1, self.llm_hidden_size)
              vit_embeds_list.append(vit_embeds)
          vit_embeds = torch.cat(vit_embeds_list, dim=0)
          return vit_embeds
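For context on the pixel_shuffle call in this hunk: in InternVL-style models it is a space-to-depth swap that reduces the number of visual tokens by a factor of 1/scale_factor**2 while widening the channel dimension. A minimal sketch (the implementation in this PR may differ in details):

    import torch

    def pixel_shuffle(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
        # x: (n, h, w, c) -> (n, h*s, w*s, c/s^2) for s = scale_factor < 1,
        # i.e. fewer spatial positions carrying proportionally more channels.
        n, h, w, c = x.shape
        x = x.view(n, h, int(w * scale_factor), int(c / scale_factor))
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(w * scale_factor), int(h * scale_factor),
                   int(c / (scale_factor ** 2)))
        return x.permute(0, 2, 1, 3).contiguous()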
Contributor


critical

The extract_feature method is missing the application of the mlp1 projection layer. This will result in features with incorrect dimensions (rmsnorm_hidden_size instead of llm_hidden_size), which could lead to runtime errors or incorrect model behavior. Other feature extraction methods in this file, like extract_feature_dynamic and extract_video_feature_temporal, correctly apply this projection. This method should be updated to include the mlp1 projection to ensure feature dimensions are correct.

    def extract_feature(self, pixel_values):
        micro_batch_size = 128
        n = pixel_values.shape[0]
        patch_size = self.config.patch_size
        h_patches = pixel_values.shape[-2] // patch_size
        w_patches = pixel_values.shape[-1] // patch_size
        vit_embeds_list = []
        for i in range(0, n, micro_batch_size):
            chunk = pixel_values[i : i + micro_batch_size]
            batch_size = chunk.shape[0]
            vit_embeds = self.vision_model(chunk)
            vit_embeds = vit_embeds.to(dtype=self.model_dtype)
            vit_embeds = vit_embeds.reshape(batch_size, h_patches, w_patches, -1)
            vit_embeds = self.pixel_shuffle(
                vit_embeds, scale_factor=self.downsample_ratio
            )
            vit_embeds = vit_embeds.view(-1, self.rmsnorm_hidden_size)
            vit_embeds = self.mlp1(vit_embeds)
            vit_embeds = vit_embeds.view(batch_size, -1, self.llm_hidden_size)
            vit_embeds_list.append(vit_embeds)
        vit_embeds = torch.cat(vit_embeds_list, dim=0)
        return vit_embeds
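For reference, in InternVL-derived models mlp1 is typically a small projector from the vision hidden size to the LLM hidden size; a sketch with hypothetical sizes, not necessarily the configuration used in this PR:

    import torch.nn as nn

    rmsnorm_hidden_size, llm_hidden_size = 3200, 4096  # hypothetical sizes
    mlp1 = nn.Sequential(
        nn.LayerNorm(rmsnorm_hidden_size),
        nn.Linear(rmsnorm_hidden_size, llm_hidden_size),
        nn.GELU(),
        nn.Linear(llm_hidden_size, llm_hidden_size),
    )

Skipping this projection leaves features at rmsnorm_hidden_size, so the later view to (batch_size, -1, llm_hidden_size) either raises a shape error or silently mis-packs tokens when the element counts happen to divide.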

            return clip_sizes

        def _subsampling_output_length(self, length: int) -> int:
            import math
Contributor


medium

According to the PEP 8 style guide, imports should be at the top of the file. Placing import math inside a method reduces readability and is not standard practice. Please move this import to the top of the module.
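A sketch of the requested fix; the method body shown is an assumption (a single ceil-divide by an assumed subsampling_factor attribute), standing in for whatever arithmetic the PR actually performs:

    import math  # moved to module scope, per PEP 8


    class ParakeetAudioEncoder:  # illustrative container only
        subsampling_factor = 8  # assumed attribute

        def _subsampling_output_length(self, length: int) -> int:
            # Convolutional subsampling shortens the input sequence by roughly
            # its stride product; ceil keeps at least one output frame.
            return math.ceil(length / self.subsampling_factor)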

@yhyang201
Collaborator Author

The CI error is unrelated to this PR. Can we merge it? @mickqian

@mickqian mickqian merged commit 4a3fe2a into sgl-project:main Apr 25, 2026
1148 of 1258 checks passed