#pragma once

#include <cstdint>   // uintptr_t
#include <utility>   // std::forward

#include "vectorization.cuh"  // provides vec_n_t<T, N>

namespace vllm {

// Adapter that lifts an element-wise scalar op (InT -> OutT) to a vector op
// (vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>) by applying it lane by
// lane. Used as the default vector path when the caller supplies only a
// scalar op.
template <typename InT, typename OutT, int VEC_SIZE, typename ScaOp>
struct DefaultVecOp {
  ScaOp scalar_op;  // element-wise functor: scalar_op(OutT& dst, const InT& src)

  __device__ __forceinline__ void operator()(
      vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      scalar_op(dst.val[i], src.val[i]);
    }
  }
};

// Applies an element-wise transform over `len` elements using vectorized
// loads/stores wherever the input pointer alignment allows it.
//
// Three phases:
//   1. scalar prefix  — elements before the first VEC_SIZE*sizeof(InT)-aligned
//                       address (unsafe to vectorize);
//   2. vector main    — full vec_n_t<*, VEC_SIZE> chunks;
//   3. scalar tail    — the < VEC_SIZE leftover elements.
//
// Parameters:
//   in, out    — input/output pointers (may be identically aligned or not;
//                only `in` alignment is inspected — NOTE(review): assumes
//                `out` shares the same misalignment as `in`, as is the case
//                for the in-place/same-offset callers; confirm for new uses).
//   len        — number of elements.
//   tid        — this thread's starting index (e.g. global thread id).
//   stride     — loop stride (e.g. total number of threads); together with
//                `tid` this forms a grid-stride loop.
//   vec_op     — vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>
//   scalar_op  — InT -> OutT, used for prefix and tail.
//
// Preconditions: VEC_SIZE is a positive power of two (enforced below).
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT* in, OutT* out, int len, int tid, int stride,
    VecOp&& vec_op,      // vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>
    ScaOp&& scalar_op) {  // InT -> OutT
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");

  constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // vector width in bytes, e.g. 64
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);

  // Number of scalar elements needed to reach the next WIDTH-aligned address.
  // The power-of-two WIDTH lets us use masks instead of '%'.
  int misalignment_offset = addr & (WIDTH - 1);       // addr % WIDTH
  int alignment_bytes = WIDTH - misalignment_offset;  // WIDTH - (addr % WIDTH)
  int prefix_elems = alignment_bytes & (WIDTH - 1);   // maps WIDTH -> 0 (already aligned)
  prefix_elems /= sizeof(InT);                        // bytes -> elements
  prefix_elems = min(prefix_elems, len);              // clamp: 0 <= prefix <= len

  // 1. scalar prefix: cover the region where a vector load would be misaligned.
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(out[i], in[i]);
  }

  // Advance past the prefix; `in` is now WIDTH-aligned (or len == 0).
  in += prefix_elems;
  out += prefix_elems;
  len -= prefix_elems;

  int num_vec = len / VEC_SIZE;
  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;
  auto* v_in = reinterpret_cast<const vin_t*>(in);
  auto* v_out = reinterpret_cast<vout_t*>(out);

  // 2. vectorized main loop over whole VEC_SIZE-element chunks.
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    vec_op(tmp, v_in[i]);
    v_out[i] = tmp;
  }

  // 3. scalar tail: the remaining len % VEC_SIZE elements.
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(out[i], in[i]);
  }
}

// Convenience overload: caller supplies only a scalar op; the vector path is
// synthesized lane-by-lane via DefaultVecOp. Same contract as the primary
// overload above.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                                         OutT* out, int len,
                                                         int tid, int stride,
                                                         ScaOp&& scalar_op) {
  using Vec = DefaultVecOp<InT, OutT, VEC_SIZE, ScaOp>;
  // DefaultVecOp holds a copy of scalar_op; the forwarded reference is used
  // for the scalar prefix/tail phases.
  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op},
                                     std::forward<ScaOp>(scalar_op));
}

}  // namespace vllm