# SPDX-License-Identifier: Apache-2.0
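"""Test attention + quantization fusion (AttnFusionPass).

Runs the same model with and without the fusion pass and checks that the
pass rewrites the graph as expected and that generated outputs are unchanged.
"""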

from typing import Optional

import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform

# Globals needed for the string-imported custom Dynamo backend field:
# CompilationConfig.backend refers to these by dotted path, so they must
# live at module scope.
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None


@pytest.mark.parametrize(
    "model, quant_key",
    [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
    "use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
                          quant_key: QuantKey, use_triton_fa: bool):
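    """Compare generations of a fused vs. unfused model.

    The unfused run applies only NoOpEliminationPass; the fused run adds
    AttnFusionPass, which folds static FP8 output quantization into the
    attention op on layers whose backend supports it.
    """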
    # Clean the Dynamo cache to avoid reusing graphs from other test cases
    # (for some reason the reset at the end is not enough).
    torch._dynamo.reset()

    # Use the module-level backends so the string-imported Dynamo backend
    # fields above resolve to them.
    global backend, backend_unfused

    use_v1 = False  # can be made a param once V1 support is added
    monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))

    # Prompt 4 seems too open-ended: its completion differs between the fused
    # and unfused runs (both outputs look reasonable though), so skip it.
    prompts = example_prompts[:4] + example_prompts[5:]
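
    # First run: unfused baseline with only no-op elimination.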
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers the custom backend & does full Dynamo
        # compilation; DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

    llm = LLM(model,
              enforce_eager=True,
              compilation_config=compile_config,
              gpu_memory_utilization=0.9,
              max_model_len=2048)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=10,
                                     top_p=0.95)

    unfused_output = llm.generate(prompts, sampling_params)
    backend_unfused = None  # reset the backend to make sure llm gets released
    del llm
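
    # Second run: same model with AttnFusionPass added on top of no-op
    # elimination.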
    compile_config = CompilationConfig(
        # DYNAMO_AS_IS triggers the custom backend & does full Dynamo
        # compilation; DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    # AttnFusionPass needs the attention layers to be registered in the config
    # by the time it is initialized, so construct it lazily, during
    # compilation, via this lambda.
    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
    llm2 = LLM(model,
               enforce_eager=True,
               compilation_config=compile_config,
               gpu_memory_utilization=0.9,
               max_model_len=2048)

    # Check which attention layers support fused output quantization for this
    # quantization scheme.
    attn_fusion_supported = [
        layer.impl.fused_output_quant_supported(quant_key.dtype,
                                                quant_key.static,
                                                quant_key.group_shape)
        for key, layer in compile_config.static_forward_context.items()
    ]

    print(f"{attn_fusion_supported=}")
    if any(attn_fusion_supported):
        # The quant ops should be present before the pass and at least
        # partially replaced after it.
        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

    # Attention ops are present in both graphs; only the output_scale param
    # changes with fusion.
    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
    assert len(attn_nodes_pre) == len(attn_nodes_post)
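
    # Layers that support fusion should now receive an output_scale; the rest
    # must be left untouched.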
    for i in range(len(attn_nodes_pre)):
        assert attn_nodes_pre[i].kwargs["output_scale"] is None
        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
        assert fused == attn_fusion_supported[i], \
            f"Node {i} {'' if attn_fusion_supported[i] else 'not '}expected " \
            f"to have fused output quant"

    # Generate with the fused model and check outputs against the baseline.
    fused_output = llm2.generate(prompts, sampling_params)

    # Transform outputs into the (token_ids, text) format expected by
    # check_outputs_equal.
    sample_outs = lambda s: (list(s.token_ids), s.text)
    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]

    check_outputs_equal(
        outputs_0_lst=outs_lst(unfused_output),
        outputs_1_lst=outs_lst(fused_output),
        name_0="unfused",
        name_1="fused",
    )

    # Clean the Dynamo cache to avoid polluting other test case(s).
    torch._dynamo.reset()

    # Reset the backend to make sure llm2 gets released.
    backend = None