# vllm/tests/v1/test_async_llm_dp.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

DP_SIZE = int(os.getenv("DP_SIZE", 2))
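
# Engine configuration shared by the tests below: a small MoE model run
# with enforce_eager, DP_SIZE data-parallel engine cores, and a TP size
# taken from the TP_SIZE env var (default 1).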
engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",
    enforce_eager=True,
    disable_log_requests=True,
    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
    data_parallel_size=DP_SIZE,
)

if not current_platform.supports_v1(engine_args.create_model_config()):
    pytest.skip(reason="Requires V1-supporting platform.",
                allow_module_level=True)


async def generate(
        engine: AsyncLLM,
        request_id: str,
        prompt: PromptType,
        output_kind: RequestOutputKind,
        max_tokens: int,
        prompt_logprobs: Optional[int] = None,
        data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
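    """Submit one request and return (num generated tokens, request_id)."""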

    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     temperature=0,
                                     prompt_logprobs=prompt_logprobs)

    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params,
                                     data_parallel_rank=data_parallel_rank):

        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens
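
        # Yield to the event loop so other request tasks can make progress.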
        await asyncio.sleep(0.)

    return count, request_id


@pytest.mark.parametrize(
    "output_kind",
    [
        RequestOutputKind.DELTA,
        RequestOutputKind.FINAL_ONLY,
    ],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind,
                    data_parallel_backend: str):
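    """Spread concurrent requests over a data-parallel AsyncLLM deployment
    and check token counts, engine quiescence, and per-engine balance."""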

    stats_loggers = {}
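
    # Minimal stat logger: one instance is constructed per DP engine core,
    # so per-engine finished-request counts can be used to verify that the
    # load was spread across the engines.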
    @dataclass
    class SimpleStatsLogger(StatLoggerBase):
        init_count: int = 0
        finished_req_count: int = 0

        def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
            stats_loggers[engine_index] = self

        def record(self, scheduler_stats: Optional[SchedulerStats],
                   iteration_stats: Optional[IterationStats]):
            if iteration_stats:
                self.finished_req_count += len(
                    iteration_stats.finished_requests)

        def log_engine_initialized(self):
            self.init_count += 1

    with ExitStack() as after:

        prompt = "This is a test of data parallel"

        engine_args.data_parallel_backend = data_parallel_backend
        engine = AsyncLLM.from_engine_args(engine_args,
                                           stat_loggers=[SimpleStatsLogger])
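        # Make sure the engine is shut down when the with-block exits,
        # even if an assertion fails.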
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             NUM_EXPECTED_TOKENS)))
            # Short sleep to ensure that requests are distributed.
            await asyncio.sleep(0.01)

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")
        assert not engine.output_processor.has_unfinished_requests()

        # testing internals here which may break
        core_client: DPAsyncMPClient = engine.engine_core

        # the engines only synchronize stopping every N steps so
        # allow a small amount of time here.
        for _ in range(10):
            if not core_client.engines_running:
                break
            await asyncio.sleep(0.5)

        assert not core_client.engines_running
        assert not core_client.reqs_in_flight

        # Check that requests were distributed between the engines
        print(f"Stats loggers after test: {stats_loggers}")
        assert len(stats_loggers) == DP_SIZE
        assert stats_loggers[0].init_count == 1
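
        # Each engine should have handled at least roughly its fair share:
        # more than NUM_REQUESTS // (DP_SIZE + 1) requests each.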
        for sl in stats_loggers.values():
            slogger: SimpleStatsLogger = sl
            assert slogger.finished_req_count > NUM_REQUESTS // (
                DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"