#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "quantization/fp8/common.cuh"

namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          typename fp8_type>
__global__ void act_and_mul_quant_kernel(
    fp8_type* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale, const int d) {
  const int32_t blocks_per_token = gridDim.y;

  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

  // We don't expect the hidden dimension to exceed 32 bits so int32 should
  // be safe here.
  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
  const int32_t elems_per_block =
      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
  const int32_t block_start = blockIdx.y * elems_per_block;
  int32_t block_end = block_start + elems_per_block;
  block_end = block_end > d ? d : block_end;

  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
  // is very large
  const int64_t token_idx = blockIdx.x;
  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
  fp8_type* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
  const int32_t vec_loop_end =
      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

  float inverted_scale = 1 / *scale;
#pragma unroll
  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
       vec_idx += blockDim.x) {
    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
    using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;

    scalar_64bit_vec_t out_vec;
    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
          ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
    }

    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
  if (block_end > vec_loop_end) {
    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
         idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] =
          scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
    }
  }
}
}  // namespace vllm

// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                \
  int d = input.size(-1) / 2;                                                \
  int64_t num_tokens = input.numel() / input.size(-1);                       \
  /* Use fewer blocks per token (grid.y) as the token count grows. */        \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);      \
  dim3 block(std::min(d, 512));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));          \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();              \
  VLLM_DISPATCH_FLOATING_TYPES(                                              \
      input.scalar_type(), "act_and_mul_kernel", [&] {                       \
        VLLM_DISPATCH_FP8_TYPES(                                             \
            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {   \
              vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>,     \
                                             fp8_t>                          \
                  <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),        \
                                               input.data_ptr<scalar_t>(),   \
                                               scale.data_ptr<float>(), d);  \
            });                                                              \
      });

void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale) {
  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
              out.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.size(-1) % 2 == 0);
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
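
// ---------------------------------------------------------------------------
// Illustrative host-side usage sketch (not part of the original source). The
// tensor shapes, dtypes, and the hard-coded scale value below are assumptions
// chosen purely for demonstration; in practice the quantization scale comes
// from the model's checkpoint/quantization metadata.
//
//   const int64_t num_tokens = 8;
//   const int64_t d = 4096;
//   auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   // Gate and up projections packed along the last dim: [..., 2 * d].
//   torch::Tensor input = torch::randn({num_tokens, 2 * d}, opts);
//   // FP8 output buffer: [..., d].
//   torch::Tensor out =
//       torch::empty({num_tokens, d}, opts.dtype(torch::kFloat8_e4m3fn));
//   // Per-tensor scale: a single float32 value on the same device.
//   torch::Tensor scale =
//       torch::full({1}, 0.5f, opts.dtype(torch::kFloat32));
//   silu_and_mul_quant(out, input, scale);
// ---------------------------------------------------------------------------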