Dynamic Request Batching#

Serve offers a request batching feature that can improve your service throughput without sacrificing latency. This is possible because ML models can use efficient vectorized computation to process a batch of requests at a time. Batching is also necessary when your model is expensive to run and you want to maximize hardware utilization.

Machine Learning (ML) frameworks such as TensorFlow, PyTorch, and Scikit-Learn support evaluating multiple samples at the same time. Ray Serve lets you take advantage of this capability through dynamic request batching. When a request arrives, Serve puts it in a queue. This queue buffers requests to form a batch. The deployment picks up the batch and evaluates it. After the evaluation, the resulting batch is split up, and each response is returned individually.

Enable batching for your deployment#

You can enable batching by using the ray.serve.batch decorator. Let’s take a look at a simple example by modifying the Model class to accept a batch.

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Model:
    def __call__(self, single_sample: int) -> int:
        return single_sample * 2


handle: DeploymentHandle = serve.run(Model.bind())
assert handle.remote(1).result() == 2

The batching decorator expects you to make the following changes to your method signature:

  • The method is declared as an async method because the decorator batches requests in the asyncio event loop.

  • The method accepts a list of its original input types as input. For example, arg1: int, arg2: str should be changed to arg1: List[int], arg2: List[str].

  • The method returns a list. The return list must be the same length as the input list so the decorator can split the output and return a corresponding response to each respective request.

from typing import List

import numpy as np

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2


handle: DeploymentHandle = serve.run(Model.bind())
responses = [handle.remote(i) for i in range(8)]
assert list(r.result() for r in responses) == [i * 2 for i in range(8)]

You can supply two optional parameters to the decorator.

  • batch_wait_timeout_s controls how long Serve should wait for a batch once the first request arrives.

  • max_batch_size controls the size of the batch. Once the first request arrives, the batching decorator waits for a full batch (up to max_batch_size) until batch_wait_timeout_s is reached. If the timeout is reached first, the batch is sent to the model regardless of its size, as shown in the sketch after this list.
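For example, here's a minimal sketch (reusing the batched Model deployment defined above) showing that a partial batch still gets flushed: only three requests are sent even though max_batch_size is 8, so the batch is dispatched once the 0.1 s batch_wait_timeout_s elapses.

# Send fewer requests than max_batch_size; the partial batch is flushed
# once batch_wait_timeout_s elapses instead of waiting indefinitely.
handle: DeploymentHandle = serve.run(Model.bind())
responses = [handle.remote(i) for i in range(3)]
assert [r.result() for r in responses] == [0, 2, 4]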

Tip

You can reconfigure your batch_wait_timeout_s and max_batch_size parameters using the set_batch_wait_timeout_s and set_max_batch_size methods:

from typing import Dict, List

import numpy as np

from ray import serve


@serve.deployment(
    # These values can be overridden in the Serve config.
    user_config={
        "max_batch_size": 10,
        "batch_wait_timeout_s": 0.5,
    }
)
class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2

    def reconfigure(self, user_config: Dict):
        self.__call__.set_max_batch_size(user_config["max_batch_size"])
        self.__call__.set_batch_wait_timeout_s(user_config["batch_wait_timeout_s"])


Use these methods in the constructor or the reconfigure method to control the @serve.batch parameters through your Serve configuration file.
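For instance, here's a hedged sketch of setting the batch parameters in the constructor instead (the default_max_batch_size constructor argument is illustrative):

from typing import List

import numpy as np

from ray import serve


@serve.deployment
class Model:
    def __init__(self, default_max_batch_size: int = 8):
        # Adjust the parameters of the serve.batch-wrapped method before
        # any requests arrive.
        self.__call__.set_max_batch_size(default_max_batch_size)
        self.__call__.set_batch_wait_timeout_s(0.1)

    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.05)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2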

Streaming batched requests#

Use an async generator to stream the outputs from your batched requests. Let’s convert the StreamingResponder class to accept a batch.

import asyncio
from typing import AsyncGenerator
from starlette.requests import Request
from starlette.responses import StreamingResponse

from ray import serve


@serve.deployment
class StreamingResponder:
    async def generate_numbers(self, max: int) -> AsyncGenerator[str, None]:
        for i in range(max):
            yield str(i)
            await asyncio.sleep(0.1)

    def __call__(self, request: Request) -> StreamingResponse:
        max = int(request.query_params.get("max", "25"))
        gen = self.generate_numbers(max)
        return StreamingResponse(gen, status_code=200, media_type="text/plain")


Decorate async generator functions with the ray.serve.batch decorator. As with non-streaming methods, the function takes in a List of inputs, and on each iteration it yields an iterable of outputs with the same length as the input batch size.

import asyncio
from typing import List, AsyncGenerator, Union
from starlette.requests import Request
from starlette.responses import StreamingResponse

from ray import serve


@serve.deployment
class StreamingResponder:
    @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.1)
    async def generate_numbers(
        self, max_list: List[int]
    ) -> AsyncGenerator[List[Union[str, StopIteration]], None]:
        for i in range(max(max_list)):
            next_numbers = []
            for requested_max in max_list:
                if requested_max > i:
                    next_numbers.append(str(i))
                else:
                    next_numbers.append(StopIteration)
            yield next_numbers
            await asyncio.sleep(0.1)

    async def __call__(self, request: Request) -> StreamingResponse:
        max = int(request.query_params.get("max", "25"))
        gen = self.generate_numbers(max)
        return StreamingResponse(gen, status_code=200, media_type="text/plain")


Calling the serve.batch-decorated function returns an async generator that can be iterated over to receive results.

Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, pass a StopIteration object into the output iterable. This terminates the generator that was returned when the serve.batch function was called with that input. When streaming the generators returned by serve.batch-decorated functions over HTTP, this lets each client's connection terminate once its own call is done, instead of waiting until the entire batch is done.
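For example, here's a minimal client-side sketch (it assumes the StreamingResponder application above is running at the default http://localhost:8000 route; the expected "01" output is illustrative): a client that asks for max=2 sees its stream close as soon as its own two numbers arrive, even if other requests in the same batch keep streaming.

import requests

# Stream the response for a short request; the connection closes once this
# request's generator yields StopIteration, independent of the rest of the batch.
response = requests.get("http://localhost:8000/?max=2", stream=True)
streamed = "".join(
    chunk.decode() for chunk in response.iter_content(chunk_size=None)
)
assert streamed == "01"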

Tips for fine-tuning batching parameters#

max_batch_size ideally should be a power of 2 (2, 4, 8, 16, …) because CPUs and GPUs are both optimized for data of these shapes. Large batch sizes incur a high memory cost as well as a latency penalty for the first few requests.

batch_wait_timeout_s should be set with the end-to-end latency SLO (Service Level Objective) in mind. For example, if your latency target is 150ms and the model takes 100ms to evaluate the batch, the batch_wait_timeout_s should be set to a value much lower than 150ms - 100ms = 50ms.
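As a back-of-the-envelope sketch (the numbers mirror the example above and are purely illustrative):

# Illustrative budget calculation for batch_wait_timeout_s.
latency_slo_s = 0.150                        # end-to-end latency target
batch_eval_s = 0.100                         # time to evaluate a full batch
headroom_s = latency_slo_s - batch_eval_s    # 0.050 s left for queueing
batch_wait_timeout_s = headroom_s * 0.2      # stay well below the headroom
assert batch_wait_timeout_s < headroom_s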

When using batching in a Serve Deployment Graph, the relationship between an upstream node and a downstream node might affect performance as well. Consider a chain of two models where the first model sets max_batch_size=8 and the second model sets max_batch_size=6. In this scenario, when the first model finishes a full batch of 8, the second model completes one batch of 6 and is left with a partial batch of 8 - 6 = 2 requests, which it holds until more requests arrive or its timeout expires, incurring latency costs. The batch size of downstream models should ideally be a multiple or divisor of the upstream models' batch size so that the batches play well together, as in the sketch below.
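For example, here's a hedged sketch (the Upstream and Downstream class names and the arithmetic they perform are illustrative) of a two-model chain where the downstream batch size of 4 divides the upstream batch size of 8, so each full upstream batch splits evenly into downstream batches:

from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Downstream:
    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
    async def __call__(self, values: List[int]) -> List[int]:
        return [v + 1 for v in values]


@serve.deployment
class Upstream:
    def __init__(self, downstream: DeploymentHandle):
        self._downstream = downstream

    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, values: List[int]) -> List[int]:
        doubled = [v * 2 for v in values]
        # Fan the batch back out as individual requests so the downstream
        # deployment can re-batch them with its own max_batch_size.
        responses = [self._downstream.remote(v) for v in doubled]
        return [await r for r in responses]


app = Upstream.bind(Downstream.bind())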