vllm/tests/compile/test_fusion_attn.py

132 lines
5.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch._dynamo
from tests.compile.backend import TestBackend
from tests.models.utils import check_outputs_equal
from vllm import LLM, SamplingParams
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.platforms import current_platform
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
# Use global backends
global backend, backend_unfused
use_v1 = False # can be made a param once V1 support added
monkeypatch.setenv("VLLM_USE_V1", str(int(use_v1)))
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
# Prompt 4 seems too open-ended, differs between fused and unfused
# (both outputs look reasonable though)
prompts = example_prompts[:4] + example_prompts[5:]
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
)
vllm_config = VllmConfig(compilation_config=compile_config)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=10,
top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
del llm
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
)
vllm_config = VllmConfig(compilation_config=compile_config)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
max_model_len=2048)
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape)
for key, layer in compile_config.static_forward_context.items()
]
print(f"{attn_fusion_supported=}")
if any(attn_fusion_supported):
# Check quant ops
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# attention ops present in both, just output_scale param changes
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
assert len(attn_nodes_pre) == len(attn_nodes_post)
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \
f"Node {i} {'' if fused else 'not '} expected " \
f"to have fused output quant"
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
# transform outputs to format expected by check_outputs_equal
sample_outs = lambda s: (list(s.token_ids), s.text)
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
check_outputs_equal(
outputs_0_lst=outs_lst(unfused_output),
outputs_1_lst=outs_lst(fused_output),
name_0="unfused",
name_1="fused",
)
# Clean Dynamo cache to avoid polluting other case(s)
torch._dynamo.reset()
# Reset backend to make sure llm2 gets released
backend = None