[Quantization] Remove FP4 emulation; Fall-back to marlin for device < 100 (#19563)

This commit is contained in:
Dipika Sikka 2025-06-16 17:33:51 -04:00 committed by GitHub
parent 90f9c2eb5c
commit 6bc7b57315
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 79 additions and 60 deletions

View File

@ -667,7 +667,13 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, scheme)
if isinstance(qkv_proj.scheme, scheme) or isinstance(
qkv_proj.scheme, CompressedTensorsW4A16Fp4
) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)

View File

@ -374,7 +374,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Fp4()
if CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
return CompressedTensorsW4A4Fp4()
else:
logger.warning_once(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4.")
return CompressedTensorsW4A16Fp4(
has_input_global_scale=True)
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(

View File

@ -18,7 +18,8 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
def __init__(self):
def __init__(self, has_input_global_scale: bool = False):
self.has_input_global_scale = has_input_global_scale
self.group_size = 16
@classmethod
@ -64,6 +65,13 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale)
if self.has_input_global_scale:
input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)
def process_weights_after_loading(self, layer) -> None:
# Process parameters for marlin repacking
@ -77,6 +85,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
requires_grad=False)
del layer.weight_global_scale
if self.has_input_global_scale:
layer.input_global_scale = torch.nn.Parameter(
layer.input_global_scale.data, requires_grad=False)
prepare_fp4_layer_for_marlin(layer)
def apply_weights(self,

View File

@ -9,8 +9,6 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
@ -21,53 +19,23 @@ logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"]
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
self.group_size = 16
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
logger.warning("Current platform does not support cutlass NVFP4."
" Running emulations.")
@classmethod
def get_min_capability(cls) -> int:
# dont restrict as emulations
return 80
return 100
def run_nvfp4_emulations(self, x: torch.Tensor, layer):
x_m, x_k = x.shape
output_dtype = x.dtype
# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
self.group_size)
# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale
# dequantize weight
w_fp4 = layer.weight.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = layer.weight_global_scale
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device, self.group_size)
# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out
@classmethod
def cutlass_fp4_supported(cls) -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501
)
return cutlass_scaled_mm_supports_fp4(capability)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
@ -152,27 +120,24 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# required by cutlass kernel; need Parameter, not ModelWeightParameter
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
if self.cutlass_nvfp4_supported:
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.cutlass_nvfp4_supported:
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
return self.run_nvfp4_emulations(x, layer)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)

View File

@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
# both outputs are float32
return cast_to_fp4(clipped_x), scale.squeeze(-1)
def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale_swizzled: torch.Tensor,
weight_global_scale: torch.Tensor):
group_size = 16
x_m, x_k = x.shape
output_dtype = x.dtype
# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)
# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale
# dequantize weight
w_fp4 = weight.data.view(torch.uint8)
w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data,
weight_global_scale, output_dtype, x.device,
group_size)
# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out