# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
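
"""Benchmark BitBLAS mixed-precision GEMM kernels (e.g. float16 activations
against int4 weights) over matrix shapes drawn from BLOOM-176B, OPT-65B, and
LLaMA-65B/70B, plus square sanity-check shapes.

Example invocation (the script name shown is illustrative; requires an NVIDIA
GPU and an installed `bitblas` package):

    python benchmark_bitblas.py --A_dtype float16 --W_dtype int4 --group_size 128
"""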

from packaging import version

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION,
)

try:
    import bitblas

    # Compare versions semantically rather than as raw strings, so that
    # e.g. "0.10.0" is correctly treated as newer than "0.9.9".
    if version.parse(bitblas.__version__) < version.parse(
        MINIMUM_BITBLAS_VERSION
    ):
        raise ImportError(
            "bitblas version is too old. Please "
            f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
        )
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError(
        "Trying to use the bitblas backend, but could not import it "
        f"with the following error: {bitblas_import_exception}. "
        "Please install bitblas through the following command: "
        f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
    ) from bitblas_import_exception

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target."
)

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.",
)
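# Note: when --target is omitted, auto_detect_nvidia_target() asks BitBLAS to
# pick a TVM target matching the locally detected NVIDIA GPU; pass --target
# explicitly to tune for a different device.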
parser.add_argument(
    "--group_size", type=int, default=None, help="Group size for grouped quantization."
)
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A.",
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16",
        "float32",
        "float64",
        "int32",
        "int8",
        "int4",
        "int2",
        "int1",
        "nf4",
        "fp4_e2m1",
    ],
    help="Data type of weight W.",
)
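# The sub-byte weight types (int4/int2/int1, nf4 = 4-bit NormalFloat,
# fp4_e2m1 = 4-bit float) are what make the mixed-precision GEMM interesting:
# A stays in a regular dtype while W is packed and dequantized in-kernel.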
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],
    help="Data type for accumulation.",
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output.",
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument(
    "--with_bias", action="store_true", help="Include bias in the benchmark."
)
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization.",
)
parser.add_argument(
    "--with_zeros", action="store_true", help="Include zeros in the quantization."
)
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros.",
)
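# In BitBLAS, the zeros_mode choices differ in how the zero-point enters
# dequantization (roughly: "original" subtracts the zero-point before scaling,
# "rescale" folds it into the scaled value, and "quantized" stores the
# zero-points themselves in quantized form); see the BitBLAS docs for details.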

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config
shared_args = [
    A_dtype,
    W_dtype,
    out_dtype,
    accum_dtype,
    layout,
    with_bias,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
]
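# The order of shared_args must match MatmulConfig's positional parameters
# after (M, N, K); MatmulConfig(*input_args) below relies on it.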

# Define just the (M, N, K) shapes in a more compact list
shapes = [
    # square test
    (1, 16384, 16384),
    # BLOOM-176B
    (1, 43008, 14336),
    (1, 14336, 14336),
    (1, 57344, 14336),
    (1, 14336, 57344),
    # OPT-65B
    (1, 9216, 9216),
    (1, 36864, 9216),
    (1, 9216, 36864),
    # LLAMA-70B/65B
    (1, 22016, 8192),
    (1, 8192, 22016),
    (1, 8192, 8192),
    (1, 28672, 8192),
    (1, 8192, 28672),
    # square test
    (16384, 16384, 16384),
    # BLOOM-176B
    (8192, 43008, 14336),
    (8192, 14336, 14336),
    (8192, 57344, 14336),
    (8192, 14336, 57344),
    # OPT-65B
    (8192, 9216, 9216),
    (8192, 36864, 9216),
    (8192, 9216, 36864),
    # LLAMA-70B/65B
    (8192, 22016, 8192),
    (8192, 8192, 22016),
    (8192, 8192, 8192),
    (8192, 28672, 8192),
    (8192, 8192, 28672),
]
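# The M=1 rows exercise the memory-bound GEMV case (single-token decode),
# while the M=8192 / M=16384 rows exercise compute-bound GEMM (prefill-like
# batches); both regimes matter for weight-only-quantized LLM serving.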

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]

benchmark_sets = []
benchmark_sets.extend(test_shapes)
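# Constructing each Matmul with enable_tuning=True triggers BitBLAS's
# hardware-aware auto-tuning for the chosen target, so the first run of this
# loop can take a while; tuned results are cached by BitBLAS so later runs
# are faster.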

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print(f"Time cost is: {kernel_latency:.3f} ms")

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)
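# "top20" in the result key reflects BitBLAS's tuning behavior of profiling
# the top-k (k=20 by default) candidate schedules and keeping the fastest one.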

# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
    col_widths[2] = max(
        col_widths[2],
        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
        len(headers[2]) + 2,
    )
    # Deliberately no early break: loop over every result so the widest
    # entry in each column sets that column's width.

# Print header
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
    )
    print(row_str)