# SPDX-License-Identifier: Apache-2.0
import gc
import tempfile

import numpy as np
import pytest
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader


def _setup_environment(model):
    engine_args = EngineArgs(model=model)
    vllm_config = engine_args.create_engine_config()
    with set_current_vllm_config(vllm_config):
        temp_file = tempfile.mkstemp()[1]
        init_distributed_environment(
            1,
            0,
            local_rank=0,
            distributed_init_method=f"file://{temp_file}",
            backend="gloo")
        # Under single-worker mode, the full model is initialized first and
        # then partitioned using GSPMD.
        ensure_model_parallel_initialized(1, 1)
    return vllm_config


MESH = None


def _get_spmd_mesh():
    global MESH
    if MESH is None:
        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (num_devices, 1)
        device_ids = np.array(range(num_devices))
        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
    return MESH


@pytest.mark.parametrize(
    "model",
    [
        "Qwen/Qwen2-1.5B-Instruct",
        # Skip large models due to CI runner disk space limitations.
        # "meta-llama/Llama-3.1-8B-Instruct",
        # "meta-llama/Llama-3.1-70B-Instruct",
    ])
def test_tpu_model_loader(model):
    # Skip the 70B test if there are fewer than 8 chips.
    # TODO: Query the chip count via the torch_xla API; the query API does not
    # yet work under SPMD, and this test runs in SPMD mode.
    if '70B' in model and xr.global_runtime_device_count() < 8:
        pytest.skip("Skipping the 70B model on TPU VMs with fewer than "
                    "8 chips to avoid OOM.")

    vllm_config = _setup_environment(model)
    loader = TPUModelLoader(load_config=vllm_config.load_config)
    mesh = _get_spmd_mesh()
    loaded_model = loader.load_model(vllm_config, vllm_config.model_config,
                                     mesh)
    del loaded_model
    gc.collect()
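

# ---------------------------------------------------------------------------
# Illustrative sketch, not exercised by the test above: how a mesh like the
# one returned by _get_spmd_mesh() is used to partition a tensor with GSPMD
# after full initialization. The helper name, tensor shape, and partition spec
# below are assumptions for illustration; xs.mark_sharding is the torch_xla
# SPMD sharding API, and running this requires an attached XLA (TPU) device.
def _example_mark_sharding():
    import torch
    import torch_xla.core.xla_model as xm

    mesh = _get_spmd_mesh()
    # Create an example weight on the XLA device, then shard its first
    # dimension across the 'x' mesh axis and replicate along 'y'.
    weight = torch.zeros(1024, 1024, device=xm.xla_device())
    xs.mark_sharding(weight, mesh, ('x', 'y'))
    return weight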