// copied and adapted from
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
//
// NOTE(review): this file had lost every template parameter list and every
// kernel launch configuration. The kernel's template parameters and the
// <<<...>>> arguments are recovered from how the names are used below; the
// per-format template arguments (QK*/QI*/block_*/VDR_*/vec_dot_*) and the
// vec_dot_q_cuda_t type name follow the upstream ggml naming and should be
// confirmed against the project's ggml headers before merging.

// Mixture-of-experts quantized matrix-vector product.
//
// Template parameters:
//   scalar_t        element type of the output buffer
//   qk              number of columns covered by one block_q_t
//   qi, vdr         per-format constants that decide how the lanes of a warp
//                   are split across quant blocks (see the loop below)
//   block_q_t       quantized block type of the weight matrix
//   vec_dot_q_cuda  device function computing one partial dot product
//
// Launch layout (as used by moe_vec_q_launch):
//   blockDim = (WARP_SIZE, GGML_CUDA_MMV_Y, 1)
//   gridDim  = (ceil(nrows / GGML_CUDA_MMV_Y), 1, tokens * top_k)
// Each threadIdx.y slice handles one output row; the threadIdx.x lanes of
// that slice cooperate on the row and are reduced with warp shuffles, so
// blockDim.x must equal WARP_SIZE. blockIdx.z enumerates the flattened
// (token, top-k slot) pairs.
//
// Arguments:
//   vx            weights, viewed as block_q_t; the matrix of expert e starts
//                 at block offset e * nrows * (ncols / qk)
//   vy            activations, viewed as int; token t starts at int offset
//                 t * token_stride and is then read as block_q8_1
//   dst           output, written at dst[blockIdx.z * nrows + row]
//   topk_ids      expert id for every (token, top-k slot) pair
//   topk          number of experts selected per token
//   ncols, nrows  shape of one expert's weight matrix
//   token_stride  distance between consecutive tokens in vy, in ints
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr,
          vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void moe_vec_q(const void* __restrict__ vx,
                                 const void* __restrict__ vy,
                                 scalar_t* __restrict__ dst,
                                 const int* topk_ids, const int topk,
                                 const int ncols, const int nrows,
                                 const int token_stride) {
  const auto row = blockIdx.x * blockDim.y + threadIdx.y;

  const auto token = blockIdx.z / topk;
  const auto expert = (topk_ids)[blockIdx.z];

  // row depends only on threadIdx.y, so all threadIdx.x lanes of a slice
  // leave together and the shuffle reduction below never sees a partial
  // slice.
  if (row >= nrows) {
    return;
  }

  const int blocks_per_row = ncols / qk;
  const int blocks_per_warp = vdr * WARP_SIZE / qi;

  // partial sum for each thread
  float tmp = 0.0f;

  const block_q_t* x =
      ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
  const block_q8_1* y =
      (const block_q8_1*)(((const int*)vy) + token * token_stride);

  for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row;
       i += blocks_per_warp) {
    const int ibx = row * blocks_per_row + i;  // x block index

    const int iby = i * (qk / QK8_1);  // y block index that aligns with ibx

    // x block quant index when casting the quants to int
    const int iqs = vdr * (threadIdx.x % (qi / vdr));

    tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
  }

  // sum up partial sums and write back result
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
  }

  if (threadIdx.x == 0) {
    dst[blockIdx.z * nrows + row] = tmp;
  }
}

// Shared host-side launcher used by every per-format wrapper below.
// Builds the grid/block layout documented on moe_vec_q and enqueues the
// kernel on `stream`. Returns without launching when there is no work,
// because a grid with a zero dimension is an invalid launch configuration.
// Note: tokens * top_k becomes gridDim.z, which CUDA limits to 65535.
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr,
          vec_dot_q_cuda_t vec_dot_q_cuda>
static void moe_vec_q_launch(const void* vx, const void* vy, scalar_t* dst,
                             const int* topk_ids, const int top_k,
                             const int tokens, const int ncols,
                             const int nrows, const int token_stride,
                             cudaStream_t stream) {
  const int num_pairs = tokens * top_k;
  if (nrows <= 0 || num_pairs <= 0) {
    return;
  }

  // Despite its name this is the grid's x dimension: one block per
  // GGML_CUDA_MMV_Y rows (name kept from the upstream code).
  const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
  const dim3 block_nums(block_num_y, 1, num_pairs);
  const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
  moe_vec_q<scalar_t, qk, qi, block_q_t, vdr, vec_dot_q_cuda>
      <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
                                              ncols, nrows, token_stride);
}

// Per-format entry points. Each one only selects the template arguments for
// its weight format and forwards to moe_vec_q_launch.

template <typename scalar_t>
static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
                   vec_dot_q4_0_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
                   vec_dot_q4_1_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
                   vec_dot_q5_0_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
                   vec_dot_q5_1_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
                   vec_dot_q8_0_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
                   vec_dot_q2_K_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
                   vec_dot_q3_K_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
                   vec_dot_q4_K_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
                   vec_dot_q5_K_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy,
                                   scalar_t* dst, const int* topk_ids,
                                   const int top_k, const int tokens,
                                   const int ncols, const int nrows,
                                   const int token_stride,
                                   cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
                   vec_dot_q6_K_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                      ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy,
                                      scalar_t* dst, const int* topk_ids,
                                      const int top_k, const int tokens,
                                      const int ncols, const int nrows,
                                      const int token_stride,
                                      cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1,
                   vec_dot_iq2_xxs_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                         ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy,
                                     scalar_t* dst, const int* topk_ids,
                                     const int top_k, const int tokens,
                                     const int ncols, const int nrows,
                                     const int token_stride,
                                     cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1,
                   vec_dot_iq2_xs_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                        ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy,
                                    scalar_t* dst, const int* topk_ids,
                                    const int top_k, const int tokens,
                                    const int ncols, const int nrows,
                                    const int token_stride,
                                    cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI2_S, block_iq2_s, 1,
                   vec_dot_iq2_s_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                       ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy,
                                      scalar_t* dst, const int* topk_ids,
                                      const int top_k, const int tokens,
                                      const int ncols, const int nrows,
                                      const int token_stride,
                                      cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1,
                   vec_dot_iq3_xxs_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                         ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy,
                                    scalar_t* dst, const int* topk_ids,
                                    const int top_k, const int tokens,
                                    const int ncols, const int nrows,
                                    const int token_stride,
                                    cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI1_S, block_iq1_s, 1,
                   vec_dot_iq1_s_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                       ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy,
                                    scalar_t* dst, const int* topk_ids,
                                    const int top_k, const int tokens,
                                    const int ncols, const int nrows,
                                    const int token_stride,
                                    cudaStream_t stream) {
  // NOTE(review): upstream may spell the qi argument QI1_S here - confirm.
  moe_vec_q_launch<scalar_t, QK_K, QI1_M, block_iq1_m, 1,
                   vec_dot_iq1_m_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                       ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy,
                                     scalar_t* dst, const int* topk_ids,
                                     const int top_k, const int tokens,
                                     const int ncols, const int nrows,
                                     const int token_stride,
                                     cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK4_NL, QI4_NL, block_iq4_nl,
                   VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>(
      vx, vy, dst, topk_ids, top_k, tokens, ncols, nrows, token_stride,
      stream);
}

template <typename scalar_t>
static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy,
                                     scalar_t* dst, const int* topk_ids,
                                     const int top_k, const int tokens,
                                     const int ncols, const int nrows,
                                     const int token_stride,
                                     cudaStream_t stream) {
  moe_vec_q_launch<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1,
                   vec_dot_iq4_xs_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                        ncols, nrows, token_stride, stream);
}

template <typename scalar_t>
static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy,
                                    scalar_t* dst, const int* topk_ids,
                                    const int top_k, const int tokens,
                                    const int ncols, const int nrows,
                                    const int token_stride,
                                    cudaStream_t stream) {
  // NOTE(review): qi is QI3_XS (not QI3_S) in the upstream launcher - confirm.
  moe_vec_q_launch<scalar_t, QK_K, QI3_XS, block_iq3_s, 1,
                   vec_dot_iq3_s_q8_1>(vx, vy, dst, topk_ids, top_k, tokens,
                                       ncols, nrows, token_stride, stream);
}