/*
Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)

The process of fast dequantization can be summarized as a combination of
bitwise operations and floating-point computations:

weight =>(bit_op / bitwise operations)=>
f16_value =>(flop / floating-point computation)=>
dequantized_weight

Since the dequantized weights typically require subtracting the zero point
and applying a scale factor, the floating-point computation step can be
fused with the zero-point subtraction and the scaling.

The following are the parts that must be modified to fuse the zero-point
subtraction and the scaling into the floating-point step.

## INT4 => FP16/BF16 or INT8 => FP16

The floating-point computation is `__hsub2`.

If there are integer zero points:

    flop(bit_op(weight)) - flop(bit_op(zp))
  = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
  = bit_op(weight) - bit_op(zp)

  so no additional modification is needed.

If there are float zero points:

    flop(bit_op(weight)) - fzp
  = sub(bit_op(weight), bias) - fzp
  = bit_op(weight) - (fzp + bias)

  where `fzp + bias` can be precomputed at weight-loading time. But this may
  cause accuracy issues, so it should not be used in most cases.

If there are no zero points:

    scale(flop(bit_op(weight)))
  = scale(sub(bit_op(weight), bias))
  = scale(bit_op(weight)) - scale(bias)
  = fma(bit_op(weight), scale_factor, -scale(bias))

  where `scale(bias)` can be cached. But this may cause accuracy issues, so
  it should not be used in most cases.

## INT8 => BF16

INT8 => BF16 is a special case: it uses `byte_perm` instead of a
floating-point computation, and `byte_perm` cannot be fused with scaling.

## FP4/FP8 => FP16/BF16

    scale(flop(bit_op(weight)))
  = scale(mul(bit_op(weight), multiplier))
  = mul(bit_op(weight), scale_factor * multiplier)

  where `scale_factor * multiplier` can be precomputed at weight-loading
  time.
*/

#include "marlin_dtypes.cuh"

namespace MARLIN_NAMESPACE_NAME {

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it
// in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// `w_type_id` selects the quantized weight format. When `skip_flop` is true,
// only the bitwise part of the conversion is performed; the caller is then
// responsible for fusing the floating-point part (zero point and/or scale).
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);

//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link
// below, with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(int q,
                                                              half2* frag_b) {
  const int MASK = 0x000f000f;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);

  frag_b[0] = *reinterpret_cast<half2*>(&lo);
  frag_b[1] = *reinterpret_cast<half2*>(&hi);
}
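// [Exposition only; the helper name below is hypothetical and not part of the
// kernel API.] A scalar reference for the magic-number trick above: OR-ing a
// nibble x into the mantissa of the fp16 constant 0x6400 (= 1024.0) produces
// 1024 + x, so one subtraction of 0x6408 (= 1032.0) recovers the signed value
// x - 8 without any int-to-float conversion instruction.
__host__ __device__ inline float dequant_u4b8_scalar_ref(uint32_t x) {
  __half_raw h;
  h.x = static_cast<unsigned short>(0x6400 | (x & 0xf));  // bit_op step
  return __half2float(__half(h)) - 1032.0f;               // flop step
}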
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(int q,
                                                               half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}

template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
                                                            half2* frag_b) {
  dequant<half2, vllm::kU4B8.id(), true>(q, frag_b);
}

template <>
__device__ inline void dequant<half2, vllm::kU4.id(), false>(int q,
                                                             half2* frag_b) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // clang-format on
  // We want unsigned int4 outputs here, so `SUB` and `ADD` only remove the
  // exponent bias and carry no `-8` zero-point term.
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
}
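// [Exposition only; hypothetical helper, bf16 analogue of the sketch above.]
// bf16 keeps only 7 mantissa bits, so the high nibble (bits 7..4) would spill
// into the exponent field and the fp16 LO/HI + MUL/ADD trick cannot be used;
// the bf16 specializations below therefore shift `q` by 4 and reuse one MASK.
// The magic constants become 0x4300 (= 128.0) and 0x4308 (= 136.0).
__host__ __device__ inline float dequant_u4b8_bf16_scalar_ref(uint32_t x) {
  __nv_bfloat16_raw h;
  h.x = static_cast<unsigned short>(0x4300 | (x & 0xf));  // bit_op step
  return __bfloat162float(__nv_bfloat16(h)) - 136.0f;     // flop step
}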
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), true>(
    int q, nv_bfloat162* frag_b) {
  static constexpr uint32_t MASK = 0x000f000f;
  static constexpr uint32_t EX = 0x43004300;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  // clang-format off
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  q >>= 4;
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
  // clang-format on

  frag_b[0] = *reinterpret_cast<nv_bfloat162*>(&lo);
  frag_b[1] = *reinterpret_cast<nv_bfloat162*>(&hi);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);

  // Subtract the exponent bias plus the `-8` symmetric zero point
  // (0x4308 == 136.0 in bf16).
  static constexpr uint32_t SUB = 0x43084308;
  frag_b[0] =
      __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
  frag_b[1] =
      __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), true>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4B8.id(), true>(q, frag_b);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kU4.id(), true>(q, frag_b);

  // Subtract only the exponent bias (0x4300 == 128.0 in bf16).
  static constexpr uint32_t SUB = 0x43004300;
  frag_b[0] =
      __hsub2(frag_b[0], *reinterpret_cast<const nv_bfloat162*>(&SUB));
  frag_b[1] =
      __hsub2(frag_b[1], *reinterpret_cast<const nv_bfloat162*>(&SUB));
}

//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16
// or bf16. Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), true>(
    int q, half2* frag_b) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  frag_b[0] = *reinterpret_cast<half2*>(&lo);
  frag_b[1] = *reinterpret_cast<half2*>(&hi);
}

template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);

  // Subtract the exponent bias plus the `-128` symmetric zero point
  // (0x6480 == 1152.0 in fp16).
  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
  frag_b[0] = __hsub2(
      frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(
      frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}

template <>
__device__ inline void dequant<half2, vllm::kU8.id(), true>(int q,
                                                            half2* frag_b) {
  dequant<half2, vllm::kU8B128.id(), true>(q, frag_b);
}

template <>
__device__ inline void dequant<half2, vllm::kU8.id(), false>(int q,
                                                             half2* frag_b) {
  dequant<half2, vllm::kU8.id(), true>(q, frag_b);

  // Subtract only the exponent bias (0x6400 == 1024.0 in fp16).
  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
  frag_b[0] = __hsub2(
      frag_b[0], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(
      frag_b[1], *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
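// [Exposition only; hypothetical helper.] A scalar reference for the INT8 =>
// BF16 path below: `__byte_perm` drops a byte b into the mantissa of the fp32
// constant 0x4B000000 (= 2^23), giving 2^23 + b; subtracting 2^23 + 128
// (8388736.0) yields the signed value b - 128 exactly, and taking the high 16
// bits of the fp32 result (the final `__byte_perm`) truncates it to bf16.
// This is the `byte_perm` special case the header comment mentions: the
// permute cannot be fused with scaling.
__device__ inline float dequant_u8b128_bf16_scalar_ref(uint32_t b) {
  uint32_t bits = 0x4B000000u | (b & 0xffu);   // bit_op step: 2^23 + b
  float f = *reinterpret_cast<float*>(&bits);  // same aliasing idiom as above
  return f - 8388736.f;                        // subtract 2^23 + 128
}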
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id(), false>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  // Subtract 2^23 + 128 to fuse the `-128` symmetric zero point.
  fp32_intermediates[0] -= 8388736.f;
  fp32_intermediates[1] -= 8388736.f;
  fp32_intermediates[2] -= 8388736.f;
  fp32_intermediates[3] -= 8388736.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id(), false>(
    int q, nv_bfloat162* frag_b) {
  float fp32_intermediates[4];
  uint32_t* fp32_intermediates_casted =
      reinterpret_cast<uint32_t*>(fp32_intermediates);

  static constexpr uint32_t fp32_base = 0x4B000000;
  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

  // Subtract only 2^23 (no zero point).
  fp32_intermediates[0] -= 8388608.f;
  fp32_intermediates[1] -= 8388608.f;
  fp32_intermediates[2] -= 8388608.f;
  fp32_intermediates[3] -= 8388608.f;

  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                   fp32_intermediates_casted[1], 0x7632);
  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
                                   fp32_intermediates_casted[3], 0x7632);
}

template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), true>(
    int q, half2* frag_b) {
  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kFE4M3fn.id(), true>(q, frag_b);

  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
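// [Exposition only; hypothetical helper.] A scalar reference for the FP8 =>
// FP16 path above: the sign bit stays at bit 15 while the 7 exponent/mantissa
// bits shift right by FP16_EXPONENT - FP8_EXPONENT = 1. The result is then
// 2^BIAS_OFFSET = 2^8 too small, because fp16 uses exponent bias 15 where
// E4M3 uses 7, so a final multiply by 256 (exact, a power of two) restores
// the value; E4M3 subnormals come out correct as well, and the E4M3fn NaN
// encoding is not preserved, matching the kernel. The same reasoning with
// BF16_EXPONENT = 8 (shift by 4, multiply by 2^120) gives the bf16 variant.
__host__ __device__ inline float dequant_fe4m3_scalar_ref(uint32_t v) {
  __half_raw h;
  h.x = static_cast<unsigned short>(((v & 0x80u) << 8) |
                                    ((v & 0x7fu) << 7));  // bit_op step
  return __half2float(__half(h)) * 256.0f;                // flop step
}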
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), true>(
    int q, half2* frag_b) {
  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70007000;

  // Extract and shift FP4 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
    int q, half2* frag_b) {
  dequant<half2, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and FP16 formats
  constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

  // Convert to half2 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(
    int q, nv_bfloat162* frag_b) {
  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT;
  constexpr int MASK = 0x70007000;

  // Extract and shift FP4 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 4;
  int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
    int q, nv_bfloat162* frag_b) {
  dequant<nv_bfloat162, vllm::kFE2M1f.id(), true>(q, frag_b);

  // Constants for FP4 (E2M1) and BF16 formats
  constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8;

  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));

  // Convert to bfloat162 and apply bias
  frag_b[1] = __hmul2(frag_b[1], bias_reg);
  frag_b[0] = __hmul2(frag_b[0], bias_reg);
}

template <typename scalar_t2>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);

template <>
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
  int Out1 = (q & 0xFF00FF00) >> 1;
  q <<= 8;
  int Out2 = (q & 0xFF00FF00) >> 1;

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const half2*>(&Out1);
  frag_b[0] = *reinterpret_cast<const half2*>(&Out2);
}

template <>
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
                                                        nv_bfloat162* frag_b) {
  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  constexpr int MASK = 0x7F007F00;

  // Extract and shift FP8 values to BF16 format
  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
  q <<= 8;
  int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);

  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
  frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME
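// [Exposition only.] A minimal usage sketch of the bit_op/flop split from the
// header comment, assuming integer zero points: run the bit_op-only
// specialization for both weight and zero point, subtract, then scale. The
// function name and argument conventions here are hypothetical, not the
// kernel's actual call sites.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
__device__ inline void example_dequant_u4_fused(int q, half2 bit_op_zp,
                                                half2 scale, half2* frag_b) {
  // bit_op only: each lane now holds `1024 + w` (magic-biased, no flop yet).
  MARLIN_NAMESPACE_NAME::dequant<half2, vllm::kU4.id(), true>(q, frag_b);
  // flop, fused with dequantization per the header derivation:
  // (bit_op(w) - bit_op(zp)) * scale.
  frag_b[0] = __hmul2(__hsub2(frag_b[0], bit_op_zp), scale);
  frag_b[1] = __hmul2(__hsub2(frag_b[1], bit_op_zp), scale);
}
#endif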