mirror of https://github.com/vllm-project/vllm.git
94 lines
2.4 KiB
Python
94 lines
2.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import sys
|
|
import types
|
|
from unittest import mock
|
|
|
|
from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
|
|
TritonPlaceholder)
|
|
|
|
|
|
def test_triton_placeholder_is_module():
    """A TritonPlaceholder instance is a genuine module named ``triton``."""
    placeholder = TritonPlaceholder()
    module_name = placeholder.__name__

    assert isinstance(placeholder, types.ModuleType)
    assert module_name == "triton"
|
|
|
|
|
|
def test_triton_language_placeholder_is_module():
    """A TritonLanguagePlaceholder instance is a module named
    ``triton.language``."""
    lang_module = TritonLanguagePlaceholder()
    expected_name = "triton.language"

    assert isinstance(lang_module, types.ModuleType)
    assert lang_module.__name__ == expected_name
|
|
|
|
|
|
def test_triton_placeholder_decorators():
    """Bare ``jit``/``autotune``/``heuristics`` decorators on the placeholder
    leave each decorated function callable and returning its argument."""
    fake_triton = TritonPlaceholder()

    @fake_triton.jit
    def jitted(x):
        return x

    @fake_triton.autotune
    def autotuned(x):
        return x

    @fake_triton.heuristics
    def heuristic(x):
        return x

    # Checked in the same order as the decorators were applied.
    for fn, arg in ((jitted, 1), (autotuned, 2), (heuristic, 3)):
        assert fn(arg) == arg
|
|
|
|
|
|
def test_triton_placeholder_decorators_with_args():
    """Parameterized decorator forms (``jit(...)``, ``autotune(...)``,
    ``heuristics(...)``) are accepted by the placeholder and leave the
    decorated function callable and returning its argument."""
    fake_triton = TritonPlaceholder()

    @fake_triton.jit(debug=True)
    def jitted(x):
        return x

    @fake_triton.autotune(configs=[], key="x")
    def autotuned(x):
        return x

    # Same shape as a real triton heuristics mapping: name -> fn(args).
    block_size_rule = {
        "BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64
    }

    @fake_triton.heuristics(block_size_rule)
    def heuristic(x):
        return x

    assert jitted(1) == 1
    assert autotuned(2) == 2
    assert heuristic(3) == 3
|
|
|
|
|
|
def test_triton_placeholder_language():
    """The language placeholder is a module named ``triton.language`` whose
    ``constexpr``, ``dtype`` and ``int64`` attributes are all ``None``."""
    lang = TritonLanguagePlaceholder()

    assert isinstance(lang, types.ModuleType)
    assert lang.__name__ == "triton.language"
    for attr_name in ("constexpr", "dtype", "int64"):
        assert getattr(lang, attr_name) is None
|
|
|
|
|
|
def test_triton_placeholder_language_from_parent():
    """``TritonPlaceholder().language`` is a TritonLanguagePlaceholder."""
    assert isinstance(TritonPlaceholder().language, TritonLanguagePlaceholder)
|
|
|
|
|
|
def test_no_triton_fallback():
    """Importing ``vllm.triton_utils`` without triton yields placeholders.

    All module-cache manipulation happens inside ``mock.patch.dict`` so the
    real ``triton`` / ``vllm.triton_utils`` entries in ``sys.modules`` are
    restored when the test ends. The ``triton_utils`` attribute on the parent
    ``vllm`` package is restored as well, because the re-import rebinds it.
    Without this, the placeholder-backed modules would leak into later tests.
    """
    import vllm

    # patch.dict snapshots sys.modules on entry and restores it on exit.
    # patch.object with the current value is a no-op on entry but puts the
    # original module back on exit.
    with mock.patch.dict(sys.modules):
        with mock.patch.object(vllm, "triton_utils", vllm.triton_utils):
            # Clear cached modules so vllm.triton_utils is re-executed.
            for name in ("triton", "triton.language", "vllm.triton_utils",
                         "vllm.triton_utils.importing"):
                sys.modules.pop(name, None)

            # A None entry in sys.modules makes `import triton` raise
            # ImportError, simulating triton not being installed.
            sys.modules["triton"] = None

            from vllm.triton_utils import HAS_TRITON, tl, triton

            assert HAS_TRITON is False
            # Compare class names rather than using isinstance: the re-import
            # creates new placeholder classes distinct from the ones imported
            # at the top of this file.
            assert triton.__class__.__name__ == "TritonPlaceholder"
            assert (triton.language.__class__.__name__ ==
                    "TritonLanguagePlaceholder")
            assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
|