mirror of https://github.com/vllm-project/vllm.git
88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import subprocess
|
|
import sys
|
|
|
|
import regex as re
|
|
|
|
# A statement that imports the top-level ``triton`` package directly, e.g.
# ``import triton``, ``import triton.language as tl`` or
# ``from triton import language``. The trailing group keeps names such as
# ``triton_kernels`` from matching.
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")

# The sanctioned spellings: import triton through the vllm.triton_utils shim.
ALLOWED_LINES = {
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
    "from vllm.triton_utils import triton",
}

# Files exempt from the check altogether (presumably the shim module that
# performs the real triton import -- confirm if this list changes).
ALLOWED_FILES = {
    "vllm/triton_utils/importing.py",
}
|
|
|
|
|
|
def is_allowed_file(current_file: str) -> bool:
    """Report whether *current_file* is exempt from the triton import check.

    The path is compared verbatim against the entries of ``ALLOWED_FILES``.
    """
    exempt = current_file in ALLOWED_FILES
    return exempt
|
|
|
|
|
|
def is_forbidden_import(line: str) -> bool:
    """Tell whether *line* is a direct triton import that is not permitted.

    Surrounding whitespace is ignored. The line is forbidden when it matches
    ``FORBIDDEN_IMPORT_RE`` and is not one of the exact statements listed in
    ``ALLOWED_LINES``.
    """
    candidate = line.strip()
    if candidate in ALLOWED_LINES:
        return False
    return FORBIDDEN_IMPORT_RE.match(candidate) is not None
|
|
|
|
|
|
# Two-way hunk header: "@@ -<old>[,<old_count>] +<new>[,<new_count>] @@".
_HUNK_HEADER_RE = re.compile(r"^@@ -\d+(?:,(\d+))? \+(\d+)(?:,(\d+))? @@")


def parse_diff(diff: str) -> list[str]:
    """Return the forbidden triton imports added by a unified diff.

    Each hunk body is consumed using the line counts from its ``@@`` header.
    An added line is therefore recognised by its leading ``+`` even when its
    own content starts with ``+``. For example, an added ``++ b/x`` appears as
    ``+++ b/x``. Such lines are neither mistaken for a file header nor dropped
    from the line numbering.

    Args:
        diff: Text produced by ``git diff`` (any ``--unified`` context size).

    Returns:
        One ``"<path>:<lineno>: <code>"`` entry for every added line that
        ``is_forbidden_import`` flags. ``<lineno>`` is the line's number in
        the new version of the file, and ``<code>`` is the stripped text.
        Files for which ``is_allowed_file`` is true are not reported.
    """
    violations: list[str] = []
    current_file = None
    skip_allowed_file = False
    # New-file line number the next added/context line will have; only a
    # placeholder until the first "@@" header is seen.
    next_lineno = 1
    # Old-/new-side lines still owed by the current hunk. While either is
    # positive we are inside a hunk body.
    old_left = new_left = 0

    for line in diff.splitlines():
        tag = line[:1]
        if (old_left > 0 or new_left > 0) and tag in ("+", "-", " ", "\\",
                                                      ""):
            if tag == "-":
                old_left -= 1
                continue
            if tag in (" ", ""):
                # Context line (some tools strip the lone space of an
                # empty one).
                old_left -= 1
                new_left -= 1
                next_lineno += 1
                continue
            if tag == "\\":
                continue  # "\ No newline at end of file"
            new_left -= 1  # added line: checked below
        else:
            # Outside a hunk. A non-body line also ends a hunk whose header
            # over-counted, so drop any stale counts.
            old_left = new_left = 0
            if line.startswith("+++ b/"):
                current_file = line[6:]
                skip_allowed_file = is_allowed_file(current_file)
                continue
            if line.startswith("@@"):
                header = _HUNK_HEADER_RE.match(line)
                if header:
                    old_count, new_start, new_count = header.groups()
                    # An omitted count means exactly one line.
                    old_left = 1 if old_count is None else int(old_count)
                    new_left = 1 if new_count is None else int(new_count)
                    next_lineno = int(new_start)
                else:
                    # Not a plain two-way header (e.g. a combined "@@@"
                    # diff): take the first "+N" as the start line and use
                    # the lenient "+"-line handling below.
                    start = re.search(r"\+(\d+)", line)
                    if start:
                        next_lineno = int(start.group(1))
                continue
            if not line.startswith("+") or line.startswith("++"):
                # File metadata: "diff --git", "--- a/...", "+++ /dev/null".
                continue

        # `line` is an added line.
        code_line = line[1:]
        if not skip_allowed_file and is_forbidden_import(code_line):
            violations.append(
                f"{current_file}:{next_lineno}: {code_line.strip()}")
        next_lineno += 1

    return violations
|
|
|
|
|
|
def get_diff(diff_type: str) -> str:
    """Return the zero-context ``git diff`` output for *diff_type*.

    Args:
        diff_type: ``"staged"`` (``git diff --cached``) or ``"unstaged"``
            (plain ``git diff``).

    Returns:
        The diff text, decoded as UTF-8 (undecodable bytes are replaced).

    Raises:
        ValueError: If *diff_type* is neither ``"staged"`` nor
            ``"unstaged"``.
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    if diff_type == "staged":
        scope = ["--cached"]
    elif diff_type == "unstaged":
        scope = []
    else:
        raise ValueError(f"Unknown diff_type: {diff_type}")

    cmd = [
        "git",
        "diff",
        *scope,
        "--unified=0",
        # Pin the output format so user/repo git config cannot break
        # parse_diff. This means no ANSI colours, no external diff driver,
        # and the "a/"/"b/" prefixes that "+++ b/<path>" parsing relies on
        # (diff.noprefix / diff.mnemonicPrefix would otherwise change them).
        "--no-color",
        "--no-ext-diff",
        "--src-prefix=a/",
        "--dst-prefix=b/",
    ]
    # Decode explicitly: text=True alone uses the locale encoding, which can
    # raise UnicodeDecodeError on UTF-8 source (e.g. under cp1252).
    return subprocess.check_output(cmd,
                                   text=True,
                                   encoding="utf-8",
                                   errors="replace")
|
|
|
|
|
|
def main():
    """Check staged and unstaged changes for direct triton imports.

    A git failure for either diff is reported on stderr and that diff is
    skipped. When violations are found, prints a summary line followed by
    one line per violation and returns 1. Otherwise returns 0.
    """
    found: list[str] = []
    for diff_type in ("staged", "unstaged"):
        try:
            found += parse_diff(get_diff(diff_type))
        except subprocess.CalledProcessError as e:
            print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)

    if not found:
        return 0

    print("❌ Forbidden direct `import triton` detected."
          " ➤ Use `from vllm.triton_utils import triton` instead.\n")
    for v in found:
        print(f"❌ {v}")
    return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns the exit status: 0 when clean, 1 when violations exist.
    raise SystemExit(main())
|