# File: vllm/tests/test_triton_utils.py (94 lines, 2.4 KiB, Python)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import types
from unittest import mock
from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
TritonPlaceholder)
def test_triton_placeholder_is_module():
    """A TritonPlaceholder must be a genuine module object named ``triton``."""
    placeholder = TritonPlaceholder()
    assert isinstance(placeholder, types.ModuleType)
    assert placeholder.__name__ == "triton"
def test_triton_language_placeholder_is_module():
    """A TritonLanguagePlaceholder must be a module named ``triton.language``."""
    lang_stub = TritonLanguagePlaceholder()
    assert isinstance(lang_stub, types.ModuleType)
    assert lang_stub.__name__ == "triton.language"
def test_triton_placeholder_decorators():
    """Bare ``@triton.jit`` / ``@autotune`` / ``@heuristics`` are pass-throughs.

    Each decorated function must still be callable and return its argument.
    """
    placeholder = TritonPlaceholder()

    @placeholder.jit
    def foo(value):
        return value

    @placeholder.autotune
    def bar(value):
        return value

    @placeholder.heuristics
    def baz(value):
        return value

    for decorated, sample in ((foo, 1), (bar, 2), (baz, 3)):
        assert decorated(sample) == sample
def test_triton_placeholder_decorators_with_args():
    """Decorator factories (decorators called with arguments) also pass through.

    ``jit(...)``, ``autotune(...)`` and ``heuristics(...)`` are invoked with
    arguments first; the resulting decorator must leave the function usable.
    """
    placeholder = TritonPlaceholder()

    @placeholder.jit(debug=True)
    def foo(value):
        return value

    @placeholder.autotune(configs=[], key="x")
    def bar(value):
        return value

    block_size_rule = {
        "BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64
    }

    @placeholder.heuristics(block_size_rule)
    def baz(value):
        return value

    for decorated, sample in ((foo, 1), (bar, 2), (baz, 3)):
        assert decorated(sample) == sample
def test_triton_placeholder_language():
    """The ``triton.language`` stub exposes common attributes as ``None``."""
    tl_stub = TritonLanguagePlaceholder()
    assert isinstance(tl_stub, types.ModuleType)
    assert tl_stub.__name__ == "triton.language"
    for attr_name in ("constexpr", "dtype", "int64"):
        assert getattr(tl_stub, attr_name) is None
def test_triton_placeholder_language_from_parent():
    """``TritonPlaceholder().language`` yields a TritonLanguagePlaceholder."""
    assert isinstance(TritonPlaceholder().language, TritonLanguagePlaceholder)
def test_no_triton_fallback():
    """Importing ``vllm.triton_utils`` without triton must yield placeholders.

    All ``sys.modules`` edits happen inside ``mock.patch.dict``, which
    snapshots ``sys.modules`` on entry and restores that snapshot on exit.
    Previously the cached modules were popped *before* entering the patch,
    so the removals became part of the snapshot and leaked into every later
    test in the session.
    """
    # Mapping "triton" to None makes `import triton` raise ModuleNotFoundError,
    # simulating an environment where triton is not installed.
    with mock.patch.dict(sys.modules, {"triton": None}):
        # Drop cached modules so the import below re-runs vllm's detection
        # logic. "triton" itself is deliberately NOT popped here: popping it
        # would remove the None sentinel installed by the patch above.
        for module_name in ("triton.language", "vllm.triton_utils",
                            "vllm.triton_utils.importing"):
            sys.modules.pop(module_name, None)

        from vllm.triton_utils import HAS_TRITON, tl, triton

        assert HAS_TRITON is False
        assert triton.__class__.__name__ == "TritonPlaceholder"
        assert (triton.language.__class__.__name__ ==
                "TritonLanguagePlaceholder")
        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
    # NOTE(review): re-importing a submodule also rebinds the attribute
    # `vllm.triton_utils` on the parent package, which patch.dict does not
    # undo; consider restoring it too if later tests access it via attribute.