Skip to content

[video processors] decode only sampled videos -> less RAM and faster processing#39600

Merged
zucchini-nlp merged 48 commits intohuggingface:mainfrom
zucchini-nlp:video-decoding
Aug 26, 2025
Merged

[video processors] decode only sampled videos -> less RAM and faster processing#39600
zucchini-nlp merged 48 commits intohuggingface:mainfrom
zucchini-nlp:video-decoding

Conversation

@zucchini-nlp
Copy link
Member

@zucchini-nlp zucchini-nlp commented Jul 23, 2025

What does this PR do?

This PR moves the video decoding code entirely into video processors, so that we can load only necessary video frames into memory. To be consistent with video processors, I also updated image processors to accept str in inputs and optionally load images.

The docs for video processors are also updated explaining how frames are sampled and what users need to do to turn it on/off. Note that we'll be using by default torchcodec and fallback to torchvision, and we won't support any arbitrary video decoders within video processor class. Otherwise we'd need to introduce more kwargs and handle differences between decoders, which bloats up the code even more

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp zucchini-nlp requested a review from qubvel August 4, 2025 13:12
Copy link
Contributor

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, it should be a great improvement!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @yonigozlan for changes in this file

@zucchini-nlp
Copy link
Member Author

run-slow: aria, aya_vision, blip, bridgetower, chameleon, clip, colpali, deepseek_vl, deepseek_vl_hybrid, emu3, eomt, flava, gemma3, gemma3n, glm4v

@github-actions
Copy link
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/aria', 'models/aya_vision', 'models/blip', 'models/bridgetower', 'models/chameleon', 'models/clip', 'models/colpali', 'models/deepseek_vl', 'models/deepseek_vl_hybrid', 'models/emu3', 'models/eomt', 'models/flava', 'models/gemma3', 'models/gemma3n', 'models/glm4v']
quantizations: [] ...

@zucchini-nlp
Copy link
Member Author

run-slow: qwen2_vl, qwen2_5_vl, qwen2_5_omni, smolvlm, llava_onevision, llava_next_video, perception_lm

@github-actions
Copy link
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/llava_next_video', 'models/llava_onevision', 'models/perception_lm', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl', 'models/smolvlm']
quantizations: [] ...

@zucchini-nlp
Copy link
Member Author

zucchini-nlp commented Aug 19, 2025

On no, new torch release doesn't work well with Bytes objects 😓 (fails only in CI, still figuring out why)

@zucchini-nlp
Copy link
Member Author

run-slow: qwen2_vl, qwen2_5_vl, qwen2_5_omni, smolvlm, llava_onevision, llava_next_video, perception_lm

1 similar comment
@zucchini-nlp
Copy link
Member Author

run-slow: qwen2_vl, qwen2_5_vl, qwen2_5_omni, smolvlm, llava_onevision, llava_next_video, perception_lm

@github-actions
Copy link
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/llava_next_video', 'models/llava_onevision', 'models/perception_lm', 'models/qwen2_5_omni', 'models/qwen2_5_vl', 'models/qwen2_vl', 'models/smolvlm']
quantizations: [] ...

@zucchini-nlp
Copy link
Member Author

The CI is impossible to pass 🙃

@zucchini-nlp
Copy link
Member Author

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Aug 26, 2025

Style bot fixed some files and pushed the changes.

@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: aria, aya_vision, blip, bridgetower, chameleon, clip, colpali, deepseek_vl, deepseek_vl_hybrid, emu3, eomt, flava, gemma3, gemma3n, glm4v

@zucchini-nlp zucchini-nlp merged commit f690a2a into huggingface:main Aug 26, 2025
24 checks passed
@shaform
Copy link

shaform commented Jan 22, 2026

@zucchini-nlp I believe it was previously possible for users to sample frames themselves and then pass the batched video frames directly to the video processor. This usage can be found in existing code, for example in the following notebook: https://huggingface.co/facebook/vjepa2-vitl-fpc16-256-ssv2/blob/main/notebook_finetuning.ipynb
.

However, this no longer seems to be supported due to changes introduced in this PR. In particular, the notebook above now produces errors when run with the current versions.

Was this behavior change intentional, or is it an unintended regression?

@zucchini-nlp
Copy link
Member Author

@shaform hey, it is still possible to sample frames and pass them as a list of 3D frames or a 4D array. You just need to pass do_sample=False in the video processor call to disable sampling. Such as:

pixel_values = video_processor(my_sampled_video, do_sample=False).pixel_values_video

@shaform
Copy link

shaform commented Jan 22, 2026

@shaform hey, it is still possible to sample frames and pass them as a list of 3D frames or a 4D array. You just need to pass do_sample=False in the video processor call to disable sampling. Such as:

pixel_values = video_processor(my_sampled_video, do_sample=False).pixel_values_video

@zucchini-nlp Thank you for your quick response. I tested a bit, and found that the reason the notebook fails now is because the processor will return a tensor with a incorrect shape if a batched input is given. Here is a minimal example to reproduce:

# input: B x T x C x H x W
processor(th.zeros(4, 5, 3, 100, 100), return_tensors="pt", do_sample_frames=False).pixel_values_videos.shape
# output: 1 x 1 x B x T x C x H x W
Out[17]: torch.Size([1, 1, 4, 5, 3, 256, 256])

It seems a simple workaround is to convert the input into a list, i.e.,

processor([t for t in th.zeros(4, 5, 3, 100, 100)], return_tensors="pt", do_sample_frames=False).pixel_values_videos.shape

Could this be considered as a bug?

@zucchini-nlp
Copy link
Member Author

@shaform yeah, should not be happening. Would you mind opening an issue with a minimal reproducer (model_id and how you call it) so I don't forget about it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants