# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import socket
import threading
import time
from typing import Optional
from unittest.mock import patch

import pytest

from vllm.v1.utils import (APIServerProcessManager,
                           wait_for_completion_or_failure)

# Global variable to control worker behavior
WORKER_RUNTIME_SECONDS = 0.5


# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
    """Mock run_api_server_worker that runs for a specific time."""
    print(f"Mock worker started with client_config: {client_config}")
    time.sleep(WORKER_RUNTIME_SECONDS)
    print("Mock worker completed successfully")


@pytest.fixture
def api_server_args():
    """Fixture to provide arguments for APIServerProcessManager."""
    sock = socket.socket()
    return {
        "target_server_fn": mock_run_api_server_worker,
        "listen_address": "localhost:8000",
        "sock": sock,
        "args": "test_args",  # Simple string to avoid pickling issues
        "num_servers": 3,
        "input_addresses": [
            "tcp://127.0.0.1:5001",
            "tcp://127.0.0.1:5002",
            "tcp://127.0.0.1:5003",
        ],
        "output_addresses": [
            "tcp://127.0.0.1:6001",
            "tcp://127.0.0.1:6002",
            "tcp://127.0.0.1:6003",
        ],
        "stats_update_address": "tcp://127.0.0.1:7000",
    }


@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
    """Test initializing the APIServerProcessManager."""
    # Set the worker runtime to ensure tests complete in reasonable time
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.5

    # Copy the args to avoid mutating the fixture
    args = api_server_args.copy()
    if not with_stats_update:
        args.pop("stats_update_address")

    manager = APIServerProcessManager(**args)
    try:
        # Verify the manager was initialized correctly
        assert len(manager.processes) == 3

        # Verify all processes are running
        for proc in manager.processes:
            assert proc.is_alive()

        print("Waiting for processes to run...")
        time.sleep(WORKER_RUNTIME_SECONDS / 2)

        # They should still be alive at this point
        for proc in manager.processes:
            assert proc.is_alive()
    finally:
        # Always clean up the processes
        print("Cleaning up processes...")
        manager.close()

        # Give processes time to terminate
        time.sleep(0.2)

        # Verify all processes were terminated
        for proc in manager.processes:
            assert not proc.is_alive()


@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
       mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
    """Test that wait_for_completion_or_failure works with failures."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 1.0

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)
    try:
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(api_server_manager=manager)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Let all processes run for a short time
        time.sleep(0.2)

        # All processes should still be running
        assert all(proc.is_alive() for proc in manager.processes)

        # Now simulate a process failure
        print("Simulating process failure...")
        manager.processes[0].terminate()

        # Wait for wait_for_completion_or_failure to detect and handle the
        # failure. This should trigger it to terminate all other processes.
        wait_thread.join(timeout=1.0)

        # The wait thread should have exited
        assert not wait_thread.is_alive()

        # Verify that an exception was raised with an appropriate error message
        assert result["exception"] is not None
        assert "died with exit code" in str(result["exception"])

        # All processes should now be terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), f"Process {i} should not be alive"
    finally:
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
    """Test that wait_for_completion_or_failure works in normal completion."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.1

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)
    try:
        # Wait for all worker processes to complete on their own
        remaining_processes = manager.processes.copy()
        while remaining_processes:
            remaining_processes = [
                proc for proc in remaining_processes if proc.is_alive()
            ]
            time.sleep(0.1)

        # Verify all processes have terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), \
                f"Process {i} still alive after terminate()"

        # Now call wait_for_completion_or_failure; since all processes have
        # already terminated, it should return immediately with no error
        wait_for_completion_or_failure(api_server_manager=manager)
    finally:
        # Clean up just in case
        manager.close()
        time.sleep(0.2)


@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
    """Test that wait_for_completion_or_failure handles additional processes."""
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 100

    # Create and start the external process
    # (simulates local_engine_manager or coordinator)
    spawn_context = multiprocessing.get_context("spawn")
    external_proc = spawn_context.Process(target=mock_run_api_server_worker,
                                          name="MockExternalProcess")
    external_proc.start()

    # Create the class to simulate a coordinator
    class MockCoordinator:

        def __init__(self, proc):
            self.proc = proc

        def close(self):
            if self.proc.is_alive():
                self.proc.terminate()
                self.proc.join(timeout=0.5)

    # Create a mock coordinator with the external process
    mock_coordinator = MockCoordinator(external_proc)

    # Create the API server manager
    manager = APIServerProcessManager(**api_server_args)
    try:
        # Verify manager initialization
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            try:
                wait_for_completion_or_failure(api_server_manager=manager,
                                               coordinator=mock_coordinator)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Terminate the external process to trigger a failure
        time.sleep(0.2)
        external_proc.terminate()

        # Wait for the thread to detect the failure
        wait_thread.join(timeout=1.0)

        # The wait thread should have completed
        assert not wait_thread.is_alive(), \
            "wait_for_completion_or_failure thread still running"

        # Verify that an exception was raised with an appropriate error message
        assert result["exception"] is not None, "No exception was raised"
        error_message = str(result["exception"])
        assert "died with exit code" in error_message, \
            f"Unexpected error message: {error_message}"
        assert "MockExternalProcess" in error_message, \
            f"Error doesn't mention external process: {error_message}"

        # Verify that all API server processes were terminated as a result
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), \
                f"API server process {i} was not terminated"
    finally:
        # Clean up
        manager.close()
        mock_coordinator.close()
        time.sleep(0.2)