diff --git a/CMakeLists.txt b/CMakeLists.txt index 8966a663d3..b1adeac586 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -648,6 +648,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu new file mode 100644 index 0000000000..33d0d4a722 --- /dev/null +++ b/csrc/custom_quickreduce.cu @@ -0,0 +1,114 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + + #include "quickreduce/quick_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, + const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, + torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + + #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index f02f5083ac..52c264d64c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -360,3 +360,14 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size = std::nullopt); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t quant_level, bool cast_bf2half = false); +int64_t qr_max_size(); +#endif \ No newline at end of file diff --git a/csrc/quickreduce/base.h b/csrc/quickreduce/base.h new file mode 100644 index 0000000000..a2170e4832 --- /dev/null +++ b/csrc/quickreduce/base.h @@ -0,0 +1,338 @@ +#pragma once + +#include +#include +#include +#include + +#define __quickreduce_device_inline__ __device__ __forceinline__ +#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) +#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) + +namespace quickreduce { + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; +using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + +// Setup acquire-release semantics for vector memory reads (mubuf instruction) +// as per architecture. +#if defined(__gfx942__) +// CDNA3: Scope bits sc0, sc1 + #define MUBUF_ACQUIRE 16 + #define MUBUF_RELEASE 16 +#elif (defined(__gfx908__) || defined(__gfx90a__)) +// CDNA1 and CDNA2 - glc bit + #define MUBUF_ACQUIRE 1 + #define MUBUF_RELEASE 0 +#endif + +static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t + +// Number of atoms (4xf16x2_t) processed by a single thread +static constexpr int kAtoms = 8; + +// We use a workgroup of 256 threads +static constexpr int kBlockSize = 256; +static constexpr int kAtomStride = kBlockSize; + +// Size and atom stride of source/destination data that the block will +// process. +// Workgroup scope = Tile = (256 threads x 8 atoms x 16B) +static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); + +// Max number of blocks. 304 CUs on MI300 +static constexpr int kMaxNumBlocks = 304 * 4; + +// Standard CDNA wavefront size. +static constexpr int kWavefront = 64; + +// 256 thread, 4 wavefronts. +static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; + +// Number of threads in a group for quantization +// It corresponds to 32 F16 elements in quantization block +static constexpr int kThreadGroupSize = 8; + +// Methods +__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, + unsigned long y) { + return ((x + y - 1) / y); +} + +union BufferResource { + __quickreduce_device_inline__ constexpr BufferResource() + : config(0x00020000U) {} + + __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, + uint32_t buffer_size) + : address(buffer_address), range(buffer_size), config(0x00020000U) {} + + int32x4_t descriptor; + struct { + void* address; // 8B, out of which first 48b is address, and 16b is stride + // (unused) + uint32_t range; // Byte range for the buffer resource + uint32_t config; // Constant, DFMT=32b + }; +}; + +__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( + int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__quickreduce_device_inline__ static void buffer_store_dwordx4( + int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, + int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { +#if defined(__gfx942__) + if (value) { + asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); + } else { + asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); + } +#endif +} +union bf162_int_union { + int i; + nv_bfloat162 bf2; +}; + +template +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B); + +template <> +__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, + int32x4_t* B) { + int32x4_t& tR_fragment = A[0]; + int32x4_t& tA_fragment = B[0]; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[0]) + : "v"(tR_fragment[0]), "v"(tA_fragment[0])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[1]) + : "v"(tR_fragment[1]), "v"(tA_fragment[1])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[2]) + : "v"(tR_fragment[2]), "v"(tA_fragment[2])); + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(tR_fragment[3]) + : "v"(tR_fragment[3]), "v"(tA_fragment[3])); +} + +template <> +__quickreduce_device_inline__ void packed_assign_add( + int32x4_t* A, int32x4_t* B) { + nv_bfloat162* tA = reinterpret_cast(A); + nv_bfloat162* tB = reinterpret_cast(B); +#pragma unroll + for (int i = 0; i < 4; i++) { + tA[i] = __hadd2(tA[i], tB[i]); + } +} + +template +__quickreduce_device_inline__ int packed_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + int result; + asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmax2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_min(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + int result; + asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_min(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hmin2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_abs_max(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + half2 wmaxh2 = __builtin_bit_cast(half2, a); + half2 wminh2 = __builtin_bit_cast(half2, b); + half2 wblockmaxh2; + + wblockmaxh2.x = + __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; + wblockmaxh2.y = + __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; + return __builtin_bit_cast(int, wblockmaxh2); +} + +template <> +__quickreduce_device_inline__ int packed_abs_max(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; + R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y; + return R.i; +} + +template +__quickreduce_device_inline__ int packed_add(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hadd2(A.bf2, B.bf2); + return R.i; +} + +template <> +__quickreduce_device_inline__ int packed_add(int a, int b) { + int result; + asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template +__quickreduce_device_inline__ int packed_sub(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + int result; + + // MI300 lacks packed fp16 sub instruction. So we do -1 * min + max + asm volatile("v_pk_fma_f16 %0, %1, %2 %3" + : "=v"(result) + : "v"(kNegOne), "v"(b), "v"(a)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_sub(int a, int b) { + bf162_int_union A, B, R; + A.i = a; + B.i = b; + R.bf2 = __hsub2(A.bf2, B.bf2); + return R.i; +} + +template +__quickreduce_device_inline__ int packed_mul(int a, int b); + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + int result; + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); + return result; +} + +template <> +__quickreduce_device_inline__ int packed_mul(int a, int b) { + nv_bfloat162* tA = reinterpret_cast(&a); + nv_bfloat162* tB = reinterpret_cast(&b); + nv_bfloat162 tR = __hmul2(*tA, *tB); + return *(reinterpret_cast(&tR)); +} + +template +__quickreduce_device_inline__ int packed_rcp(int a); + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); +} + +template <> +__quickreduce_device_inline__ int packed_rcp(int a) { + bf162_int_union A, R; + A.i = a; + R.bf2 = h2rcp(A.bf2); + return R.i; +} + +// changes dtype +__quickreduce_device_inline__ float T2float_cast(half a) { + return __half2float(a); +} + +__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { + return __bfloat162float(a); +} + +template +__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { + const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; + + int wmax, wmin, wblockmax; + int a, b; + a = packed_max(atom[0], atom[1]); + b = packed_max(atom[2], atom[3]); + + wmax = packed_max(a, b); + + a = packed_min(atom[0], atom[1]); + b = packed_min(atom[2], atom[3]); + + wmin = packed_min(a, b); + + // Reduce the max among a group of threads + // Note: This is basically 2 blocks of values setup as the + // upper/lower halves of the f16x2_t + for (int i = 1; i < kThreadGroupSize; i <<= 1) { + int x = __shfl_down(wmax, i); + wmax = packed_max(wmax, x); + + int y = __shfl_down(wmin, i); + wmin = packed_min(wmin, y); + } + wblockmax = packed_abs_max(wmax, wmin); + // Share with the cohort + wblockmax = __shfl(wblockmax, group_leader); + return wblockmax; +} + +__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); +} + +__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, + uint32_t flag) { + while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { + } +} + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h new file mode 100644 index 0000000000..4fe4c44be7 --- /dev/null +++ b/csrc/quickreduce/quick_reduce.h @@ -0,0 +1,196 @@ +#pragma once + +#include +#include +#include "quick_reduce_impl.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \ + hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void +allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, + int rank, uint8_t** dbuffer_list, + uint32_t data_offset, uint32_t flag_color) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, + flag_color); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL((allreduce_prototype_twoshot), \ + dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ + num_blocks, rank, dbuffer_list, data_offset, \ + flag_color); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of uint32_t max value. + int64_t kMaxProblemSize = + static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { destroy(); } + + void init(int world_size, int rank, + std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = + 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, + hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. + all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { return world_size; } + int get_rank() { return rank; } + bool status() { return initialized; } + hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], + all_buffer_ipc_handles[i], + hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), + world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, + hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + + std::to_string(world_size)); + } + + // Configuration. + uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(cudaGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh new file mode 100644 index 0000000000..17816c552d --- /dev/null +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -0,0 +1,698 @@ +#pragma once + +#include +#include "base.h" + +namespace quickreduce { + +struct CodecBase { + const int thread; + const int rank; + const int group_leader; + __quickreduce_device_inline__ CodecBase(int thread, int rank) + : thread(thread), + rank(rank), + group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { + set_fp16_ovfl(true); + } +}; + +// Default full precision codec. +template +struct CodecFP : public CodecBase { + static constexpr int kWorldSize = world_size; + static constexpr int kRankAtoms = kAtoms / kWorldSize; + + // Codec tile size process by this workgroup. + // Each thread processes atoms of f16x8_t (16B). + static constexpr int kRankTransmittedTileSize = + kBlockSize * kRankAtoms * sizeof(int32x4_t); + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + __quickreduce_device_inline__ CodecFP(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + __builtin_nontemporal_store(data[i], send_buffer + thread); + send_buffer += kAtomStride; + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int i = 0; i < kRankAtoms; i++) { + data[i] = __builtin_nontemporal_load(*recv_buffer + thread); + *recv_buffer += kAtomStride; + } + } +}; + +// Int4 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int4 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ4 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int4x8_t (4B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1152; + static constexpr int kRankTileScaleOffset = 1024; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/8.0h, -1/8.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xB000B000 : 0xBE00BE00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-8, -8}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xC800C800 : 0xC100C100; + + // {+7, +7}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x47004700 : 0x40E040E0; + + // {+8, +8}, int16x2_t + static constexpr int kRangeBias = 0x00080008; + + __quickreduce_device_inline__ CodecQ4(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q4 into int32_t + int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q4 into f16x8_t + int32x4_t w; + { + static constexpr uint kMask000F = 0x000F000F; + static constexpr uint kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1032 = + 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t + + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; + w[i] = packed_add(q4, kHalf2_1032); + } else { + int32_t int16_2 = (qw >> (i * 4)) & kMask000F; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Int6 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int6 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ6 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of fp16x8_t (16B), + // into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 1664; + static constexpr int kRankTileQ2Offset = 1024; + static constexpr int kRankTileScaleOffset = 1536; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTransmittedTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/32.0h, -1/32.0h}, fp16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA800A800 : 0xBD00BD00; + + // {1e-7, 1e-7}, fp16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-32, -32}, fp16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD000D000 : 0xC200C200; + + // {+31, +31}, fp16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x4FC04FC0 : 0x41F841F8; + + // {+32, +32}, int16x2_t + static constexpr int kRangeBias = 0x00200020; + + __quickreduce_device_inline__ CodecQ6(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + const int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q6 into int32_t + int16_t + uint32_t q4w; + uint16_t q2w = 0; + q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | + ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); + { + int16_t* tw = reinterpret_cast(&q); +#pragma unroll + for (int i = 0; i < 8; i++) { + q2w |= (tw[i] >> 4) << (i * 2); + } + } + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(q4w, q4w_ptr); + __builtin_nontemporal_store(q2w, q2w_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + uint32_t* q4w_ptr = reinterpret_cast(atom_ptr) + thread; + uint16_t* q2w_ptr = + reinterpret_cast(atom_ptr + kRankTileQ2Offset) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); + uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q6 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask000F = 0x000F000F; + static uint constexpr kHalf2_1024 = + 0x64006400; // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1056 = + 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t + +#pragma unroll + for (int i = 0; i < 4; i++) { + int32_t q4 = q4w & kMask000F; + int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); + q4w >>= 4; + q2w >>= 4; + if constexpr (std::is_same::value) { + int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(w[i]) + : "v"(q6), "v"(kHalf2_1056)); + } else { + int32_t int16_2 = q4 | (q2 << 4); + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + // That's pretty much it... + data[k] = w; + } + } +}; + +// Int8 symmetric quantization codec. +// We quantize the FP16 data to block-scaled Int8 in blocks of 4 * +// kThreadGroupSize. +template +struct CodecQ8 : public CodecBase { + static constexpr int kWorldSize = world_size; + + // Codec tile size process by this workgroup. + // Each threads processes a fragment of f16x8_t (16B), + // into a int8x8_t (8B) and a f16 scale shared among 32 values. + static constexpr int kRankAtoms = kAtoms / kWorldSize; + static constexpr int kRankTileStride = 2176; + static constexpr int kRankTileScaleOffset = 2048; + static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; + static_assert(kRankTransmittedTileSize % 16 == 0, + "kRankTileSize must be 16B aligned."); + + static constexpr int kRankBufferTileStride = + kRankTileStride / sizeof(int32x4_t); + + // Total tile size for the collective communication. + static constexpr int kTransmittedTileSize = + kRankTransmittedTileSize * kWorldSize; + + // Constants configuration + + // {-1/128.0h, -1/128.0h}, f16x2_t + static constexpr int kScaleFactor = + std::is_same::value ? 0xA000A000 : 0xBC00BC00; + + // {1e-7, 1e-7}, f16x2_t + static constexpr int kScaleEpsilon = + std::is_same::value ? 0x00010001 : 0x33D733D7; + + // {-128, -128}, f16x2_t + static constexpr int kRangeMin = + std::is_same::value ? 0xD800D800 : 0xC300C300; + // {+127, +127}, f16x2_t + static constexpr int kRangeMax = + std::is_same::value ? 0x57F057F0 : 0x42FE42FE; + + // {+128, +128}, int16x2_t + static constexpr int kRangeBias = 0x00800080; + + __quickreduce_device_inline__ CodecQ8(int thread, int rank) + : CodecBase(thread, rank) {} + + __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, + int32x4_t const* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + int32x4_t const atom = data[k]; + // Compute the absolute maximum of the atom in the thread group + // In 2 blocks of values, upper/lower halves of the f16x2_t + int wblockmax = group_abs_max(atom); + + // Derive scales + int decoding_scale; + int encoding_scale; + decoding_scale = packed_mul(wblockmax, kScaleFactor); + encoding_scale = packed_add(decoding_scale, kScaleEpsilon); + encoding_scale = packed_rcp(encoding_scale); + + // Apply scales to get quantized values + int32x4_t w; + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(atom[i], encoding_scale); + w[i] = packed_max(w[i], kRangeMin); + w[i] = packed_min(w[i], kRangeMax); + } + + // Convert from f16x2_t to uint16x2_t + int32x4_t q; + { + int16_t* qi = reinterpret_cast(&q); + T* wh = reinterpret_cast(&w); + for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); + + for (int i = 0; i < 4; i++) { + q[i] = packed_add(q[i], kRangeBias); + } + } + + // Pack 8 x q8 into int32x2_t + int32x2_t qw; + qw[0] = q[0] | (q[1] << 8); + qw[1] = q[2] | (q[3] << 8); + + // Write quantized atom to send_buffer + // note: only the group leader stores the scale + uint8_t* atom_ptr = + reinterpret_cast(send_buffer + k * kRankBufferTileStride); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + __builtin_nontemporal_store(qw, qw_ptr); + if (threadIdx.x == group_leader) { + __builtin_nontemporal_store(decoding_scale, qs_ptr); + } + } + } + + __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, + int32x4_t* __restrict__ data) { + for (int k = 0; k < kRankAtoms; k++) { + // Directly read quantized atom from recv_buffer + uint8_t* atom_ptr = reinterpret_cast(*recv_buffer); + int32x2_t* qw_ptr = reinterpret_cast(atom_ptr) + thread; + int* qs_ptr = reinterpret_cast(atom_ptr + kRankTileScaleOffset) + + (thread / 8); + + int32x2_t qw = __builtin_nontemporal_load(qw_ptr); + int qs = __builtin_nontemporal_load(qs_ptr); + + *recv_buffer += kRankBufferTileStride; + + // Unpack q8 into fp16x8_t + int32x4_t w; + { + static uint constexpr kMask00FF = 0x00FF00FF; + + // {1024.0, 1024.0}, fp16x2_t + static uint constexpr kHalf2_1024 = 0x64006400; + + // {-1152.0, -1152.0}, fp16x2_t + static uint constexpr kHalf2_1152 = 0xE480E480; + +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (std::is_same::value) { + int32_t q8 = + ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; + w[i] = packed_add(q8, kHalf2_1152); + } else { + int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; + int16_t low = static_cast(int16_2 & 0xFFFF); + int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); + nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); + int32_t packed_bf16 = *reinterpret_cast(&bf2); + w[i] = packed_add(packed_bf16, kRangeMin); + } + } + } + + // Apply decoding scales + for (int i = 0; i < 4; i++) { + w[i] = packed_mul(w[i], qs); + } + + data[k] = w; + } + } +}; + +// Twoshot All Reduce +template +struct AllReduceTwoshot { + static_assert(sizeof(T) == 2); + + static constexpr int kWorldSize = Codec::kWorldSize; + + __device__ static void run( + T const* __restrict__ input, T* __restrict__ output, + uint32_t const N, // number of elements + int const block, // block index + int const rank, // rank index + uint8_t** __restrict__ buffer_list, // communication buffers + uint32_t const data_offset, // offset to start of the data buffer + uint32_t flag_color) { + // Topology + int thread = threadIdx.x + threadIdx.y * kWavefront; + uint8_t* rank_buffer = buffer_list[rank]; + Codec codec(thread, rank); + int block_id = blockIdx.x; + int grid_size = gridDim.x; + // -------------------------------------------------------- + // Read input into registers + int32x4_t tA[kAtoms]; + + BufferResource src_buffer(const_cast(input), N * sizeof(T)); + uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); + src_offset += kAtomStride * sizeof(int32x4_t); + if constexpr (cast_bf2half) { + const nv_bfloat162* bf_buf = + reinterpret_cast(&tA[i]); + half2 half_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __bfloat1622float2(bf_buf[j]); + half_buf[j] = __float22half2_rn(f); + } + tA[i] = *reinterpret_cast(half_buf); + } + } + + // -------------------------------------------------------- + // Phase-1A: Write segment data into the communication buffer of the target + // rank responsible for this segment. + uint32_t comm_data0_offset = + data_offset + block_id * Codec::kTransmittedTileSize; + uint32_t comm_data1_offset = + grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + + uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); + uint32_t comm_flags1_offset = + grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data0_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + // -------------------------------------------------------- + // Phase-1B: Reduce the segment data from the communication buffers. + int32x4_t tR[Codec::kRankAtoms] = {}; + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data0_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags0_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // note: we reuse tA as temp buffer here + codec.recv(&recv_buffer, tA); + + for (int i = 0; i < Codec::kRankAtoms; i++) { + packed_assign_add(&tR[i], &tA[i]); + } + } + } + + // Phase-2: Write the reduced segment to every other rank + for (int r = 0; r < kWorldSize; r++) { + int32x4_t* send_buffer = + reinterpret_cast(buffer_list[r] + comm_data1_offset + + rank * Codec::kRankTransmittedTileSize); + codec.send(send_buffer, tR); + } + + __syncthreads(); + if (thread < kWorldSize) { + int r = thread; + uint32_t* flag_ptr = reinterpret_cast( + buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t)); + set_sync_flag(flag_ptr, flag_color); + } + + // Phase-2: Read the gather segments from the rank's communication buffer. + { + // Read the data from the communication buffer. + int32x4_t* recv_buffer = + reinterpret_cast(rank_buffer + comm_data1_offset); + uint32_t* flag_ptr = + reinterpret_cast(rank_buffer + comm_flags1_offset); + + for (int r = 0; r < kWorldSize; r++) { + // Wait for the flags to be set. + if (thread == 0) { + wait_sync_flag(&flag_ptr[r], flag_color); + } + __syncthreads(); + + // Gather all reduced and final rank segments into tA. + codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]); + } + } + + // -------------------------------------------------------- + // Write the result to output. + BufferResource dst_buffer(output, N * sizeof(T)); + uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t); + + for (int i = 0; i < kAtoms; i++) { + if constexpr (cast_bf2half) { + const half2* half_buf = reinterpret_cast(&tA[i]); + nv_bfloat162 bf16_buf[4]; +#pragma unroll + for (int j = 0; j < 4; ++j) { + float2 f = __half22float2(half_buf[j]); + bf16_buf[j] = __float22bfloat162_rn(f); + } + buffer_store_dwordx4(*reinterpret_cast(bf16_buf), + dst_buffer.descriptor, dst_offset, 0, 0); + } else { + buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0); + } + dst_offset += kAtomStride * sizeof(int32x4_t); + } + } +}; + +} // namespace quickreduce \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1a1896b4c1..8bb71cad29 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -725,6 +725,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle); custom_ar.def("free_shared_buffer", &free_shared_buffer); +#ifdef USE_ROCM + // Quick Reduce all-reduce kernels + custom_ar.def( + "qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool " + "cast_bf2half) -> ()"); + custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce); + + custom_ar.def("init_custom_qr", &init_custom_qr); + custom_ar.def("qr_destroy", &qr_destroy); + + custom_ar.def("qr_get_handle", &qr_get_handle); + + custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()"); + custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles); + + // Max input size in bytes + custom_ar.def("qr_max_size", &qr_max_size); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py new file mode 100644 index 0000000000..a4added291 --- /dev/null +++ b/tests/distributed/test_quick_all_reduce.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random + +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_group, graph_capture) +from vllm.platforms import current_platform + +from ..utils import (ensure_model_parallel_initialized, + init_test_distributed_environment, multi_process_parallel) + +torch.manual_seed(42) +random.seed(44) +# Size over 8MB is sufficient for custom quick allreduce. +test_sizes = [ + random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8) +] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +@ray.remote(num_gpus=1, max_calls=1) +def graph_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + ensure_model_parallel_initialized(tp_size, pp_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + + for sz in test_sizes: + for dtype in [torch.float16, torch.bfloat16]: + with graph_capture(device=device) as graph_capture_context: + inp1 = torch.randint(1, + 23, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(-23, + 1, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): + for _ in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1) + torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1) + + +@ray.remote(num_gpus=1, max_calls=1) +def eager_quickreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pp_size, + rank, + distributed_init_port, +): + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + # Size over 8MB is sufficient for custom quick allreduce. + sz = 16 * 1024 * 1024 + fa = get_tp_group().device_communicator.qr_comm + inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], + dtype=torch.float16, + device=device) + out = fa.quick_all_reduce(inp) + torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) + + inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], + dtype=torch.bfloat16, + device=device) + out = fa.quick_all_reduce(inp) + torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) + + +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="only test quick allreduce for rocm") +@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) +@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) +def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size, test_target, + quant_mode): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) + + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, + test_target) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d5a4128438..215f35bad3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1748,6 +1748,38 @@ def free_shared_buffer(ptr: int) -> None: torch.ops._C_custom_ar.free_shared_buffer(ptr) +# quick all reduce +def init_custom_qr(rank: int, + world_size: int, + qr_max_size: Optional[int] = None) -> int: + return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) + + +def qr_destroy(fa: int) -> None: + torch.ops._C_custom_ar.qr_destroy(fa) + + +def qr_all_reduce(fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool = False) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, + cast_bf2half) + + +def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops._C_custom_ar.qr_get_handle(fa) + + +def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + return torch.ops._C_custom_ar.qr_open_handles(fa, handles) + + +def qr_max_size() -> int: + return torch.ops._C_custom_ar.qr_max_size() + + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 055d91690e..3958d566b1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -8,6 +8,7 @@ from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform from .base_device_communicator import DeviceCommunicatorBase @@ -41,6 +42,8 @@ class CudaCommunicator(DeviceCommunicatorBase): CustomAllreduce) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -50,6 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ) self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: Optional[QuickAllReduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -57,6 +61,14 @@ class CudaCommunicator(DeviceCommunicatorBase): device=self.device, ) + if current_platform.is_rocm(): + # Initialize a custom quick all-reduce implementation for AMD. + # Quick reduce is designed as a complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). + # If it's a rocm, 'use_custom_allreduce==True' means it must + # currently be an MI300 series. + self.qr_comm = QuickAllReduce(group=self.cpu_group, + device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -79,8 +91,14 @@ class CudaCommunicator(DeviceCommunicatorBase): raise ValueError(f"Unknown all2all backend: {all2all_backend}") def all_reduce(self, input_): - # always try custom allreduce first, - # and then pynccl. + # always try quick reduce first, then custom allreduce, + # and then pynccl. (quick reduce just for ROCM MI3*) + qr_comm = self.qr_comm + if qr_comm is not None and not qr_comm.disabled and \ + qr_comm.should_quick_allreduce(input_): + out = qr_comm.quick_all_reduce(input_) + assert out is not None + return out ca_comm = self.ca_comm if ca_comm is not None and not ca_comm.disabled and \ ca_comm.should_custom_ar(input_): diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 0000000000..c61231e2d3 --- /dev/null +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +logger = init_logger(__name__) + +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs and CUDA + quick_ar = False + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + +MB = 1024 * 1024 + + +class QuickAllReduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # The following data is based on kernel tests. + # In this order [FP, INT8, INT6, INT4]. + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } + + def __init__(self, group: ProcessGroup, + device: Union[int, str, torch.device]) -> None: + """ + Custom allreduce provides non-destructive acceleration and is + available for CUDA and ROCm MI300 series. + + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 + quantization formats and FP(float16, bfloat16). + + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. + + Only the ROCm MI300 series is supported for quick allreduce at + this time. + + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self.disabled = True + if not self._rocm_arch_available(): + logger.debug( + "Custom quick allreduce is only supported on ROCm MI300 series." + ) + return + + if not quick_ar: + # disable because of missing quick reduce library + # e.g. in a cuda environment + logger.info("Custom quick allreduce is disabled because " + "of missing custom quick allreduce library") + return + + self.group = group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "Custom quick allreduce should be attached to a non-NCCL group.") + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom quick allreduce for + # multi-node case. + logger.warning("Custom quick allreduce is disabled because this " + "process group spans across nodes.") + return + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size + if world_size == 1: + # No need to initialize QuickReduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom quick allreduce is disabled due to an " + "unsupported world size: %d. Supported world sizes: %s.", + world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(self.world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom quick allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + self.fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if self.world_size > 2 and not self.fully_connected: + logger.debug( + "Custom quick allreduce is disabled because it's not supported " + "on more than two PCIe-only GPUs. ") + return + + self.init_quick_all_reduce() + + def init_quick_all_reduce(self): + # On RocM, bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is set to 1, we convert input to fp16 + self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable " + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'") + return + self.qr_quant_level = QuickReduceRegime[regime_str] + vllm_config = get_current_vllm_config() + if vllm_config is not None and \ + hasattr(vllm_config, "model_config") and \ + hasattr(vllm_config.model_config, "dtype"): + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + logger.debug( + "Custom quick allreduce disabled: only supports " + "float16 and float16, but get %s.", dtype) + return + + if dtype == torch.bfloat16 and self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: BF16 inputs will be converted " + "to FP16 to improve performance. set " + "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 " + "to turn off.") + + # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB + qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB + if qr_max_size is not None: + if qr_max_size < 1: + logger.info( + "You should not set a max_size smaller than 1MB, which can " + "lead to error or degradation to custom allreduce or rccl." + ) + qr_max_size = qr_max_size * MB + self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) + self.qr_max_size = qr_max_size if qr_max_size is not None \ + else ops.qr_max_size() + self.create_shared_buffer() + self.disabled = False + + def _rocm_arch_available(self): + if not current_platform.is_rocm(): + return False + try: + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + supported_archs = ['gfx94', 'gfx95'] + return any(gfx in gcn_arch for gfx in supported_archs) + except Exception as e: + logger.warning("Failed to determine ROCm for quick allreduce: %s", + e) + return False + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after init_custom_qr + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled: + return False + if inp.dtype not in self._SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return inp_size <= self.qr_max_size and \ + inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ + [self.qr_quant_level.value] + + def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + # quick allreduce doesn't require a separate graph mode, + # as QR uses static IPC buffer. + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value, + self.use_fp16_kernels) + return out + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + if ops is not None: + ops.qr_destroy(self._ptr) + self._ptr = 0 + self.disabled = True + + def __del__(self): + self.close() diff --git a/vllm/envs.py b/vllm/envs.py index c9c81603a7..a3f19c7ee5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -135,6 +135,9 @@ if TYPE_CHECKING: VLLM_KV_CACHE_LAYOUT: Optional[str] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" + VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True + VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None def get_default_cache_root(): @@ -690,6 +693,31 @@ environment_variables: dict[str, Callable[[], Any]] = { lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # Custom quick allreduce kernel for MI3* cards + # Choice of quantization level: FP, INT8, INT6, INT4 or NONE + # Recommended for large models to get allreduce + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": + lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(), + + # Custom quick allreduce kernel for MI3* cards + # Due to the lack of the bfloat16 asm instruction, bfloat16 + # kernels are slower than fp16, + # If environment variable is set to 1, the input is converted to fp16 + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": + lambda: + (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in + ("true", "1")), + + # Custom quick allreduce kernel for MI3* cards. + # Controls the maximum allowed number of data bytes(MB) for custom quick + # allreduce communication. + # Default: 2048 MB. + # Data exceeding this size will use either custom allreduce or RCCL + # communication. + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": + lambda: maybe_convert_int( + os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), + # If set, when running in Quark emulation mode, do not dequantize the # weights at load time. Instead, dequantize weights on-the-fly during # kernel execution.