#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "quantization/fp8/common.cuh"

namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Activation and gating kernel template.
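// The input holds two contiguous halves per token: the first d elements are
// the gate values x (passed through ACT_FN) and the last d elements are the
// companion values y. Each output element is the quantized fp8 product
// ACT_FN(x[i]) * y[i].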
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          typename fp8_type>
__global__ void act_and_mul_quant_kernel(
    fp8_type* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const float* scale, const int d) {
  const int32_t blocks_per_token = gridDim.y;

  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

  // We don't expect the hidden dimension to exceed 32 bits so int32 should
  // be safe here.
  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
  const int32_t elems_per_block =
      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
  const int32_t block_start = blockIdx.y * elems_per_block;
  int32_t block_end = block_start + elems_per_block;
  block_end = block_end > d ? d : block_end;

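  // Illustrative example (assumed values, not enforced anywhere): with
  // scalar_t = half, elems_per_128bit_load = 16 / 2 = 8. For d = 4096 and
  // blocks_per_token = 2, tgt_elems_per_block = 2048, which is already a
  // multiple of 8, so block 0 covers elements [0, 2048) and block 1 covers
  // [2048, 4096) of the token's hidden dimension.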
  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
  // is very large
  const int64_t token_idx = blockIdx.x;
  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
  fp8_type* __restrict__ out_ptr = out + token_idx * d;

  // 128-bit vectorized code
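  // Each iteration loads one int4 (16 bytes) of x and of y, i.e.
  // elems_per_128bit_load scalar_t values. Because the host-side checks
  // restrict scalar_t to 16-bit types, that is 8 elements per load, and the
  // 8 one-byte fp8 results are written back as a single 64-bit (int2) store.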
  const int32_t vec_loop_end =
      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
  const int32_t vec_start_idx = block_start / elems_per_128bit_load;

  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

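  // Read the scale once per thread and keep its reciprocal;
  // scaled_fp8_conversion then scales each element with a multiply instead
  // of a per-element division.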
  float inverted_scale = 1 / *scale;
#pragma unroll
  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
       vec_idx += blockDim.x) {
    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
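    // Reinterpret the raw 128-bit loads as arrays of scalar_t so the
    // activation and gating can be applied element-wise; out_vec packs the
    // fp8 results into a 64-bit value for the store below.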
    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
    using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;

    scalar_64bit_vec_t out_vec;
    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
    for (int i = 0; i < elems_per_128bit_load; i++) {
      out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
          ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
    }

    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
  }

  // Scalar cleanup code
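  // Handles the tail of the block when block_end is not a multiple of the
  // 128-bit load width, one element at a time.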
  if (block_end > vec_loop_end) {
    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
         idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] =
          scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
    }
  }
}
}  // namespace vllm

// Launch activation, gating, and quantize kernel.
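// grid.x assigns one block per token; grid.y splits the hidden dimension
// across 4, 2, or 1 blocks per token as the token count grows, so small
// batches still launch enough blocks to keep the GPU busy.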
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                               \
  int d = input.size(-1) / 2;                                               \
  int64_t num_tokens = input.numel() / input.size(-1);                      \
  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);    \
  dim3 block(std::min(d, 512));                                             \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));         \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();             \
  VLLM_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), "act_and_mul_kernel", [&] {                      \
        VLLM_DISPATCH_FP8_TYPES(                                            \
            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {  \
              vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>,    \
                                             fp8_t>                         \
                  <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),       \
                                               input.data_ptr<scalar_t>(),  \
                                               scale.data_ptr<float>(), d); \
            });                                                             \
      });

void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale) {
  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
              out.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.size(-1) % 2 == 0);
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
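
// Illustrative host-side usage (a sketch only; the op is presumably exposed
// to Python through the torch bindings elsewhere, and num_tokens / d below
// are assumed values):
//
//   torch::Tensor input = torch::randn(
//       {num_tokens, 2 * d},
//       torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   torch::Tensor out = torch::empty(
//       {num_tokens, d},
//       torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
//   torch::Tensor scale = torch::ones(
//       {1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   silu_and_mul_quant(out, input, scale);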