// Declarations shared by the Marlin kernel sources: the kernel's runtime
// parameter list (as a macro) and the templated `Marlin` kernel entry point.

// Namespace that all Marlin symbols are placed in. Defaults to `marlin`; a
// build may predefine MARLIN_NAMESPACE_NAME to place them under another name.
#ifndef MARLIN_NAMESPACE_NAME
  #define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

// Full runtime parameter list of the Marlin kernel, spelled once as a macro
// so the same signature can be reused verbatim wherever the kernel is
// declared or defined. (Do not add `//` comments inside the macro body: a
// line comment ending in the continuation `\` would swallow the next line.)
//
// Every pointer is __restrict__, a promise to the compiler that the buffers
// do not alias one another. The const (read-only) inputs are A, B,
// scales_ptr, scale2_ptr, zp_ptr and g_idx. C and C_tmp are non-const,
// presumably the output and a scratch/partial-result buffer (TODO confirm).
// A, B, C, C_tmp, scales_ptr and zp_ptr are typed as int4 (a 16-byte vector
// type), so they are addressed in 16-byte units. The logical element type
// behind them is not visible in this header.
// NOTE(review): judging only by the names, prob_m/prob_n/prob_k look like the
// GEMM problem sizes, lda the leading dimension of A, num_groups the number
// of quantization groups, g_idx a per-row group index, and locks a workspace
// for cross-block synchronization. use_atomic_add, use_fp32_reduce and
// max_shared_mem look like runtime switches/limits. None of this is
// established here; verify against the kernel definition.
#define MARLIN_KERNEL_PARAMS                                                   \
  const int4 *__restrict__ A, const int4 *__restrict__ B,                      \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                          \
      const int4 *__restrict__ scales_ptr,                                     \
      const uint16_t *__restrict__ scale2_ptr,                                 \
      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,          \
      int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
      bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {

// Declaration (no definition in this header) of the templated Marlin CUDA
// kernel. It is a `__global__` function whose runtime parameters are exactly
// MARLIN_KERNEL_PARAMS.
//
// NOTE(review): the template header below is damaged in this copy and will
// not compile as written.
// - `template shared` has no opening `<`, yet the parameter list is closed
//   by the `>` just before `__global__`.
// - The surviving fragment "shared // fetch pipeline" reads like the tail of
//   a comment such as "...async global->shared fetch pipeline". This
//   suggests that everything from the `<` up to that `->` was lost, perhaps
//   stripped as if it were a markup tag. That would include the leading
//   template parameters and whichever parameter that comment described.
// - Only `group_blocks` and `is_zp_float` survive intact.
// TODO: restore the full parameter list from the upstream source rather than
// guessing it. Any explicit instantiations or launches of Marlin<...>
// elsewhere depend on the exact list.
template shared  // fetch pipeline
    const int group_blocks,  // number of consecutive 16x16 blocks
                             // with a separate quantization scale
    const bool is_zp_float   // is zero point of float16 type?
    >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}  // namespace MARLIN_NAMESPACE_NAME