# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import glob
import itertools
import os
import subprocess

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# the int8-with-zero-point case (vllm::kU8) is also supported;
# we don't add it, to reduce wheel size.
SCALAR_TYPES = [
    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
    "vllm::kFE2M1f"
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act-order case
#   = -1: channelwise quantization
#   > 0 : group_size = 16 * group_blocks
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # the act-order case only supports gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                    "vllm::kU4B8", "vllm::kU8B128"
            ]:
                continue

            if thread_configs[2] == 256:
                # for small batch (m_blocks <= 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # for fp8, we only support channelwise quantization
            # and group_size == 128
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue
            # nvfp4 only supports group_size == 16
            if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
                continue
            # other quantization methods don't support group_size == 16
            if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            is_zp_float_list = [False]
            if dtype == "fp16" and scalar_type == "vllm::kU4" and \
                    group_blocks == 4:
                # HQQ (is_zp_float = true) only supports
                # 4-bit quantization and fp16
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                template_str = jinja2.Template(TEMPLATE).render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename),
                  "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()
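

# A minimal, illustrative sketch (the helper below is hypothetical and is
# never called by the generator): rendering TEMPLATE for one concrete
# configuration shows the shape of every emitted instantiation. The values
# mirror the fp16 / vllm::kU4B8 channelwise case with the (128, 128, 256)
# thread config and m_blocks == 1, where n_blocks == k_blocks == 128 // 16.
def _example_render():
    # hypothetical demo helper, not part of the generation pipeline
    return jinja2.Template(TEMPLATE).render(
        scalar_t="half",
        w_type_id="vllm::kU4B8.id()",
        threads=256,
        thread_m_blocks=1,
        thread_n_blocks=8,
        thread_k_blocks=8,
        m_block_size_8=False,
        stages="pipe_stages",
        group_blocks=-1,
        is_zp_float=False,
    )
    # -> template __global__ void Marlin<half, vllm::kU4B8.id(), 256, 1, 8,
    #    8, false, pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );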