[data][llm] Add per-stage map kwargs for build_llm_processor preprocess/postprocess #57812

@nrghosh

Description

Goal: let users control the resources and concurrency of the preprocess and postprocess stages inside build_llm_processor (the Ray Data LLM batch processor), so they can scale those stages independently of the main LLM stage without external workarounds. For example, this allows provisioning a fractional CPU per task to improve utilization, whereas Ray Data would default to 1 CPU per task.

Summary

  • Today build_llm_processor wires preprocess -> vLLM -> postprocess using Dataset.map(...) calls with fixed arguments.
  • map defaults: ~1 CPU/task; concurrency ~= num blocks (bounded by available CPUs).
  • Users cannot tune pre/post stages (e.g., num_cpus) independently.
  • Feature Request: allow passing through map kwargs per stage.

Motivation

  • Common pipelines (e.g., image captioning) have a lightweight but CPU-bound preprocess step and a cheap postprocess step, both of which should scale differently from the LLM/GPU stage.
  • Today the workaround is to set preprocess=None/postprocess=None and wrap the processor with external Ray Data ops.
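
The workaround described above can be sketched roughly as follows. This is only an illustration, not code from the Ray repository: `run_with_external_stages` is a hypothetical helper name, and the sketch assumes the processor was built with preprocess=None/postprocess=None and that the dataset exposes a `map(fn, **kwargs)` method in the style of Ray Data's `Dataset.map`.

```python
from typing import Any, Callable, Dict, Optional

Row = Dict[str, Any]


def run_with_external_stages(
    processor: Callable[[Any], Any],
    dataset: Any,
    preprocess: Optional[Callable[[Row], Row]] = None,
    postprocess: Optional[Callable[[Row], Row]] = None,
    preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
    postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Hypothetical helper: apply pre/post stages outside the processor.

    `processor` is assumed to have been built without its own
    preprocess/postprocess, so each stage here gets its own map kwargs.
    """
    if preprocess is not None:
        # e.g. preprocess_map_kwargs={"num_cpus": 0.5}
        dataset = dataset.map(preprocess, **(preprocess_map_kwargs or {}))
    # Main LLM stage, unchanged.
    dataset = processor(dataset)
    if postprocess is not None:
        # e.g. postprocess_map_kwargs={"num_cpus": 0.25}
        dataset = dataset.map(postprocess, **(postprocess_map_kwargs or {}))
    return dataset
```

Every user who needs per-stage tuning has to carry a wrapper like this, which is the boilerplate the proposed kwargs would remove.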

Proposal

  • Extend API:

    build_llm_processor(
        config,
        preprocess: Optional[Callable] = None,
        postprocess: Optional[Callable] = None,
        preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
        postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
    )
  • In ProcessorBase (e.g., python/ray/llm/_internal/batch/processor/base.py), forward the given kwargs to the corresponding map calls:

    if self.preprocess:
        dataset = dataset.map(self.preprocess, **(self.preprocess_map_kwargs or {}))
    ...
    if self.postprocess:
        dataset = dataset.map(self.postprocess, **(self.postprocess_map_kwargs or {}))
  • Validate keys against supported Dataset.map kwargs; warn on unknown keys.

  • Defaults None → current behavior unchanged.
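
One possible shape for the key validation mentioned above is sketched here. The helper name `validate_map_kwargs` and the `SUPPORTED_MAP_KWARGS` set are assumptions made for illustration; the set is a small subset of `Dataset.map` keyword arguments, not an authoritative list. Because the proposal says to warn rather than reject, unknown keys are still passed through.

```python
import warnings
from typing import Any, Dict, FrozenSet, Optional

# Assumption: an illustrative subset of Dataset.map keyword arguments.
# A real implementation would need to decide the full supported set.
SUPPORTED_MAP_KWARGS: FrozenSet[str] = frozenset(
    {"num_cpus", "num_gpus", "memory", "concurrency"}
)


def validate_map_kwargs(
    kwargs: Optional[Dict[str, Any]],
    stage: str,
    supported: FrozenSet[str] = SUPPORTED_MAP_KWARGS,
) -> Dict[str, Any]:
    """Hypothetical helper: warn on unknown keys and return a copy.

    None maps to an empty dict, so the default path adds no map kwargs.
    """
    if kwargs is None:
        return {}
    unknown = sorted(set(kwargs) - supported)
    if unknown:
        warnings.warn(
            f"{stage}_map_kwargs contains keys outside the supported set: {unknown}",
            UserWarning,
            stacklevel=2,
        )
    # Unknown keys are kept; this sketch only warns about them.
    return dict(kwargs)
```

For example, `validate_map_kwargs({"num_cpus": 0.5}, "preprocess")` returns the dict unchanged without a warning, while a misspelled key such as `"num_cpu"` triggers a `UserWarning`.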

Use case

Example

proc = build_llm_processor(
    processor_config,
    preprocess=caption_preprocess,   # fn(row) -> row
    postprocess=caption_postprocess, # fn(row) -> row
    preprocess_map_kwargs={"num_cpus": 0.5},
    postprocess_map_kwargs={"num_cpus": 0.25},
)
result = proc(input_ds)

Acceptance Criteria

  • Users can set per-stage num_cpus via the new kwargs.
  • Unit tests confirm kwargs are honored (resource specs attached to map tasks; concurrency respected).
  • Docs updated (processor API, example snippet).
  • Backward compatibility preserved.
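
A unit test for the forwarding behavior could look roughly like the sketch below. `RecordingDataset` and `MiniProcessor` are hypothetical stand-ins written for this example, not Ray classes; they only check that kwargs reach the `map` calls and that the None defaults add nothing. Confirming that resource specs are actually attached to Ray tasks would still need a test against a real Ray Data dataset.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Row = Dict[str, Any]


class RecordingDataset:
    """Stand-in dataset that records the kwargs passed to each map call."""

    def __init__(self, rows: List[Row], calls: Optional[List[Tuple[str, dict]]] = None):
        self.rows = list(rows)
        self.calls = calls if calls is not None else []

    def map(self, fn: Callable[[Row], Row], **kwargs: Any) -> "RecordingDataset":
        self.calls.append((fn.__name__, kwargs))
        return RecordingDataset([fn(dict(r)) for r in self.rows], self.calls)


class MiniProcessor:
    """Stand-in for the proposed forwarding logic (LLM stage omitted)."""

    def __init__(self, preprocess=None, postprocess=None,
                 preprocess_map_kwargs=None, postprocess_map_kwargs=None):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.preprocess_map_kwargs = preprocess_map_kwargs
        self.postprocess_map_kwargs = postprocess_map_kwargs

    def __call__(self, dataset):
        if self.preprocess:
            dataset = dataset.map(self.preprocess, **(self.preprocess_map_kwargs or {}))
        if self.postprocess:
            dataset = dataset.map(self.postprocess, **(self.postprocess_map_kwargs or {}))
        return dataset


def pre(row: Row) -> Row:
    return {**row, "pre": True}


def post(row: Row) -> Row:
    return {**row, "post": True}


def test_kwargs_are_forwarded():
    proc = MiniProcessor(pre, post, {"num_cpus": 0.5}, {"num_cpus": 0.25})
    out = proc(RecordingDataset([{"id": 1}]))
    assert out.calls == [("pre", {"num_cpus": 0.5}), ("post", {"num_cpus": 0.25})]
    assert out.rows == [{"id": 1, "pre": True, "post": True}]


def test_defaults_add_no_kwargs():
    out = MiniProcessor(pre, post)(RecordingDataset([{"id": 1}]))
    assert out.calls == [("pre", {}), ("post", {})]
```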

Labels

  • data: Ray Data-related issues
  • enhancement: Request for new feature and/or capability
  • llm
  • triage: Needs triage (eg: priority, bug/not-bug, and owning component)
