Skip to content

[Serve] Downstream deployments over-provision when receiving DeploymentResponse arguments from slow upstream #60624

@abrarsheikh

Description

Summary

When a downstream deployment receives a DeploymentResponse as an argument (i.e., the result of an upstream deployment call), it increments num_queued_requests before resolving the argument. This causes the autoscaler to see inflated queue metrics and provision more replicas than necessary.

Root Cause

In router.py, the _route_and_send_request_once method:

with self._metrics_manager.wrap_queued_request(is_retry, num_curr_replicas):
    # num_queued_requests is ALREADY incremented here
    if not pr.resolved:
        await self._resolve_request_arguments(pr)  # This can take a LONG time

The wrap_queued_request context manager increments num_queued_requests immediately. Then _resolve_request_arguments awaits any DeploymentResponse arguments, which blocks until the upstream deployment responds. During this time, the downstream deployment's queue metrics are artificially inflated.

Trade-off

This is a trade-off between two behaviors:

| Behavior | Pros | Cons |
| --- | --- | --- |
| Current (increment before resolve) | Faster cold-start scaling — downstream can begin scaling while upstream is processing | Over-provisions in steady state with chained deployments |
| Alternative (increment after resolve) | Accurate queue metrics; no over-provisioning | Slower scaling for chained deployments — must wait for upstream before scaling downstream |

Reproduction

@serve.deployment
class SlowUpstream:
    """Upstream deployment that takes ~5s per request, simulating slow work."""

    async def __call__(self) -> str:
        await asyncio.sleep(5)  # Simulate slow processing
        return "result"

@serve.deployment(autoscaling_config={"target_ongoing_requests": 1, "min_replicas": 1, "max_replicas": 10})
class FastDownstream:
    """Autoscaled downstream whose processing is instant; its real queue should stay near empty."""

    async def __call__(self, data: str) -> str:
        return f"processed: {data}"  # Instant processing

@serve.deployment
class Router:
    """Ingress that chains deployments by passing the upstream
    DeploymentResponse directly as the downstream call's argument."""

    def __init__(self, up: DeploymentHandle, down: DeploymentHandle):
        self._up, self._down = up, down

    async def __call__(self):
        # Pass upstream response directly to downstream
        return await self._down.remote(self._up.remote())

# Send 5 requests -> FastDownstream scales to 5 replicas even though
# it processes requests instantly. It's just waiting for arguments.

Observed Behavior

  • FastDownstream scales from 1 to 5 replicas while SlowUpstream is processing
  • Controller metrics show 5 queued requests for FastDownstream even though none have arrived yet

Expected Behavior (if over-provisioning is undesirable)

  • FastDownstream should not scale up until it actually receives requests to process
  • Queue metrics should reflect actual pending work, not argument resolution time

Questions for Discussion

  1. Is the current cold-start optimization worth the over-provisioning cost in steady state?
  2. Should this be configurable per-deployment?
  3. Should metrics distinguish between "waiting for arguments" vs "actually queued"?
Tests that show this behavior in action:
import pytest

import ray
from ray import serve
from ray._common.test_utils import SignalActor, wait_for_condition
from ray.serve._private.common import DeploymentID, ReplicaState
from ray.serve._private.constants import (
    RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
    SERVE_DEFAULT_APP_NAME,
    SERVE_NAMESPACE,
)
from ray.serve._private.test_utils import check_num_replicas_eq, check_num_replicas_gte
from ray.serve.handle import DeploymentHandle


def get_controller():
    """Look up and return the Serve controller actor handle."""
    # Local import keeps this private constant out of the module namespace.
    from ray.serve._private.constants import SERVE_CONTROLLER_NAME

    controller = ray.get_actor(SERVE_CONTROLLER_NAME, namespace=SERVE_NAMESPACE)
    return controller


def get_running_replica_count(name: str) -> int:
    """Return how many RUNNING replicas the controller reports for `name`
    in the default application."""
    deployment_id = DeploymentID(name=name, app_name=SERVE_DEFAULT_APP_NAME)
    state_dump = ray.get(
        get_controller()._dump_replica_states_for_testing.remote(deployment_id)
    )
    running = state_dump.get([ReplicaState.RUNNING])
    return len(running)


def get_total_requests(name: str) -> int:
    """Return the controller-side total request count tracked for `name`
    in the default application."""
    deployment_id = DeploymentID(name=name, app_name=SERVE_DEFAULT_APP_NAME)
    controller = get_controller()
    total = ray.get(
        controller._get_total_num_requests_for_deployment_for_testing.remote(
            deployment_id
        )
    )
    return total


@pytest.fixture
def serve_instance():
    """Start Ray in the Serve namespace and yield a fresh SignalActor;
    tear down Serve and Ray after the test."""
    ray.init(namespace=SERVE_NAMESPACE)
    gate = SignalActor.remote()
    yield gate
    serve.shutdown()
    ray.shutdown()


@pytest.mark.skipif(
    not RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
    reason="Needs metric collection at handle.",
)
class TestOverprovisioningFromArgumentResolution:
    """Reproduces downstream over-provisioning: a request is counted as queued
    on the downstream while its DeploymentResponse argument is still being
    resolved by a (blocked) upstream deployment."""

    def test_downstream_overscales_while_waiting_for_upstream(self, serve_instance):
        """Downstream scales up while blocked on upstream argument resolution."""
        signal = serve_instance  # SignalActor gating the upstream's responses

        @serve.deployment(max_ongoing_requests=100)
        class Upstream:
            async def __call__(self):
                # Blocks until the test releases the signal below.
                await signal.wait.remote()
                return "result"

        @serve.deployment(
            max_ongoing_requests=5,
            # Short intervals/delays so any over-scaling shows up within seconds.
            autoscaling_config={
                "target_ongoing_requests": 1,
                "metrics_interval_s": 0.1,
                "min_replicas": 1,
                "max_replicas": 10,
                "upscale_delay_s": 0.2,
                "downscale_delay_s": 0.5,
                "look_back_period_s": 0.5,
            },
        )
        class Downstream:
            async def __call__(self, data: str):
                return f"processed: {data}"

        @serve.deployment(max_ongoing_requests=100)
        class Router:
            def __init__(self, up: DeploymentHandle, down: DeploymentHandle):
                self._up, self._down = up, down

            async def __call__(self):
                # Chained call: downstream must resolve the upstream response
                # before its replica can run.
                return await self._down.remote(self._up.remote())

        handle = serve.run(Router.bind(Upstream.bind(), Downstream.bind()))
        # Wait for steady state (1 replica each) before issuing traffic.
        wait_for_condition(check_num_replicas_eq, name="Downstream", target=1)
        wait_for_condition(check_num_replicas_eq, name="Upstream", target=1)

        responses = [handle.remote() for _ in range(5)]

        # Bug: Downstream scales up while just waiting for upstream arguments
        wait_for_condition(
            check_num_replicas_eq, name="Downstream", target=5, timeout=5
        )
        replicas = get_running_replica_count("Downstream")
        print(f"Downstream over-provisioned to {replicas} replicas")

        # Unblock the upstream and verify all chained requests still complete.
        ray.get(signal.send.remote())
        for r in responses:
            assert r.result() == "processed: result"

    def test_controller_sees_inflated_request_count(self, serve_instance):
        """Controller metrics show requests for downstream before upstream responds."""
        signal = serve_instance  # SignalActor gating the upstream's responses

        @serve.deployment(max_ongoing_requests=100)
        class Upstream:
            async def __call__(self):
                await signal.wait.remote()
                return "data"

        @serve.deployment(
            max_ongoing_requests=100,
            # Autoscaling is configured, but this test asserts on the raw
            # controller request metric rather than on replica counts.
            autoscaling_config={
                "target_ongoing_requests": 5,
                "metrics_interval_s": 0.1,
                "min_replicas": 1,
                "max_replicas": 10,
                "upscale_delay_s": 1,
                "downscale_delay_s": 1,
                "look_back_period_s": 1,
            },
        )
        class Downstream:
            async def __call__(self, data: str):
                return f"got: {data}"

        @serve.deployment(max_ongoing_requests=100)
        class Router:
            def __init__(self, up: DeploymentHandle, down: DeploymentHandle):
                self._up, self._down = up, down

            async def __call__(self):
                return await self._down.remote(self._up.remote())

        handle = serve.run(Router.bind(Upstream.bind(), Downstream.bind()))
        wait_for_condition(check_num_replicas_eq, name="Downstream", target=1)
        wait_for_condition(check_num_replicas_eq, name="Upstream", target=1)

        responses = [handle.remote() for _ in range(10)]

        # Bug: Controller sees 10 requests for Downstream while they're blocked on Upstream
        def check_inflated():
            return get_total_requests("Downstream") == 10

        wait_for_condition(check_inflated, timeout=10)
        print(f"Downstream shows {get_total_requests('Downstream')} requests "
              f"while blocked on Upstream")

        # Release the upstream; all responses should now resolve.
        ray.get(signal.send.remote())
        for r in responses:
            r.result()


if __name__ == "__main__":
    # Allow running this reproduction file directly with pytest.
    raise SystemExit(pytest.main(["-v", "-s", __file__]))

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions