[Feature] add quick all reduce (#19744)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
li haoyang 2025-06-27 11:54:24 +08:00 committed by GitHub
parent 44d2e6af63
commit 0740e29b66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1879 additions and 2 deletions

View File

@ -648,6 +648,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if CUDA endif
endif()
if (VLLM_GPU_LANG STREQUAL "HIP")
# Add QuickReduce kernels
list(APPEND VLLM_EXT_SRC
"csrc/custom_quickreduce.cu"
)
# if ROCM endif
endif()
message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C

114
csrc/custom_quickreduce.cu Normal file
View File

@ -0,0 +1,114 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include "quickreduce/quick_reduce.h"
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size) {
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size == 6)
throw std::invalid_argument("world size == 6 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
fptr->init(world_size, rank, qr_max_size);
return (quickreduce::fptr_t)fptr;
}
void qr_destroy(quickreduce::fptr_t _fa) {
if (_fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
}
}
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
hipIpcMemHandle_t handle = fa->get_handle();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
return data_handle;
}
void qr_open_handles(quickreduce::fptr_t _fa,
const std::vector<torch::Tensor>& handles) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
std::vector<hipIpcMemHandle_t> ipc_handles;
ipc_handles.reserve(handles.size());
for (auto& handle : handles) {
// Ensure the tensor is on the same device as the current device.
hipIpcMemHandle_t ipc_handle;
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
ipc_handles.push_back(ipc_handle);
}
fa->open_ipc_handles(ipc_handles);
}
void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
if (out.scalar_type() == at::ScalarType::Half) {
fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(), quant_level, stream);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
if (cast_bf2half) {
fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(), quant_level, stream);
} else {
fa->allreduce<quickreduce::nv_bfloat16, false>(
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
out.numel(), quant_level, stream);
}
} else {
throw std::runtime_error(
"quick allreduce only supports float16 and bfloat16");
}
}
int64_t qr_max_size() {
// The default is 2GB (2,147,483,648 bytes)
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
}
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \
cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \
cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
#endif // USE_ROCM

View File

@ -360,3 +360,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
#ifdef USE_ROCM
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif

338
csrc/quickreduce/base.h Normal file
View File

@ -0,0 +1,338 @@
#pragma once
#include <cstdint>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#define __quickreduce_device_inline__ __device__ __forceinline__
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
namespace quickreduce {
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
// as per architecture.
#if defined(__gfx942__)
// CDNA3: Scope bits sc0, sc1
#define MUBUF_ACQUIRE 16
#define MUBUF_RELEASE 16
#elif (defined(__gfx908__) || defined(__gfx90a__))
// CDNA1 and CDNA2 - glc bit
#define MUBUF_ACQUIRE 1
#define MUBUF_RELEASE 0
#endif
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
// Number of atoms (4xf16x2_t) processed by a single thread
static constexpr int kAtoms = 8;
// We use a workgroup of 256 threads
static constexpr int kBlockSize = 256;
static constexpr int kAtomStride = kBlockSize;
// Size and atom stride of source/destination data that the block will
// process.
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
// Max number of blocks. 304 CUs on MI300
static constexpr int kMaxNumBlocks = 304 * 4;
// Standard CDNA wavefront size.
static constexpr int kWavefront = 64;
// 256 thread, 4 wavefronts.
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
// Number of threads in a group for quantization
// It corresponds to 32 F16 elements in quantization block
static constexpr int kThreadGroupSize = 8;
// Methods
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x,
unsigned long y) {
return ((x + y - 1) / y);
}
union BufferResource {
__quickreduce_device_inline__ constexpr BufferResource()
: config(0x00020000U) {}
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address,
uint32_t buffer_size)
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
int32x4_t descriptor;
struct {
void* address; // 8B, out of which first 48b is address, and 16b is stride
// (unused)
uint32_t range; // Byte range for the buffer resource
uint32_t config; // Constant, DFMT=32b
};
};
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
int32x4_t srsrc, int32_t voffset, int32_t soffset,
int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__quickreduce_device_inline__ static void buffer_store_dwordx4(
int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset,
int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
#if defined(__gfx942__)
if (value) {
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
} else {
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
}
#endif
}
union bf162_int_union {
int i;
nv_bfloat162 bf2;
};
template <typename T>
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
int32x4_t* B);
template <>
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A,
int32x4_t* B) {
int32x4_t& tR_fragment = A[0];
int32x4_t& tA_fragment = B[0];
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(tR_fragment[0])
: "v"(tR_fragment[0]), "v"(tA_fragment[0]));
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(tR_fragment[1])
: "v"(tR_fragment[1]), "v"(tA_fragment[1]));
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(tR_fragment[2])
: "v"(tR_fragment[2]), "v"(tA_fragment[2]));
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(tR_fragment[3])
: "v"(tR_fragment[3]), "v"(tA_fragment[3]));
}
template <>
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(
int32x4_t* A, int32x4_t* B) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
#pragma unroll
for (int i = 0; i < 4; i++) {
tA[i] = __hadd2(tA[i], tB[i]);
}
}
template <typename T>
__quickreduce_device_inline__ int packed_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
int result;
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmax2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_min(int a, int b);
template <>
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
int result;
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmin2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
half2 wmaxh2 = __builtin_bit_cast(half2, a);
half2 wminh2 = __builtin_bit_cast(half2, b);
half2 wblockmaxh2;
wblockmaxh2.x =
__hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
wblockmaxh2.y =
__hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
return __builtin_bit_cast(int, wblockmaxh2);
}
template <>
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_add(int a, int b);
template <>
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
int result;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hadd2(A.bf2, B.bf2);
return R.i;
}
template <>
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
int result;
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <typename T>
__quickreduce_device_inline__ int packed_sub(int a, int b);
template <>
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
int result;
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
asm volatile("v_pk_fma_f16 %0, %1, %2 %3"
: "=v"(result)
: "v"(kNegOne), "v"(b), "v"(a));
return result;
}
template <>
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hsub2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_mul(int a, int b);
template <>
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
int result;
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hmul2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
}
template <typename T>
__quickreduce_device_inline__ int packed_rcp(int a);
template <>
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
}
template <>
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
bf162_int_union A, R;
A.i = a;
R.bf2 = h2rcp(A.bf2);
return R.i;
}
// changes dtype
__quickreduce_device_inline__ float T2float_cast(half a) {
return __half2float(a);
}
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
return __bfloat162float(a);
}
template <typename T>
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
int wmax, wmin, wblockmax;
int a, b;
a = packed_max<T>(atom[0], atom[1]);
b = packed_max<T>(atom[2], atom[3]);
wmax = packed_max<T>(a, b);
a = packed_min<T>(atom[0], atom[1]);
b = packed_min<T>(atom[2], atom[3]);
wmin = packed_min<T>(a, b);
// Reduce the max among a group of threads
// Note: This is basically 2 blocks of values setup as the
// upper/lower halves of the f16x2_t
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
int x = __shfl_down(wmax, i);
wmax = packed_max<T>(wmax, x);
int y = __shfl_down(wmin, i);
wmin = packed_min<T>(wmin, y);
}
wblockmax = packed_abs_max<T>(wmax, wmin);
// Share with the cohort
wblockmax = __shfl(wblockmax, group_leader);
return wblockmax;
}
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
uint32_t flag) {
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
}
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr,
uint32_t flag) {
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
}
}
} // namespace quickreduce

View File

@ -0,0 +1,196 @@
#pragma once
#include <vector>
#include <hip/hip_runtime.h>
#include "quick_reduce_impl.cuh"
#define HIP_CHECK(err) \
do { \
hipError_t err_ = (err); \
if (err_ != hipSuccess) { \
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \
hipGetErrorString(err_)); \
throw std::runtime_error("HIP error"); \
} \
} while (0)
namespace quickreduce {
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
block += grid;
flag_color++;
}
}
#define TWOSHOT_DISPATCH(__codec) \
if (world_size == 2) { \
using LineCodec = __codec<T, 2>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
}
enum QuickReduceQuantLevel {
F16 = 0,
INT8 = 1,
INT6 = 2,
INT4 = 3,
};
struct DeviceComms {
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
int64_t kMaxProblemSize =
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
// Max TP-8
static int constexpr kMaxWorldSize = 8;
bool initialized = false;
uint32_t flag_color = 1;
int world_size;
int rank;
uint8_t* dbuffer;
uint8_t** dbuffer_list;
hipIpcMemHandle_t buffer_ipc_handle;
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
std::vector<uint8_t*> buffer_list;
uint32_t data_offset;
DeviceComms() : initialized(false), world_size(1), rank(0) {}
~DeviceComms() { destroy(); }
void init(int world_size, int rank,
std::optional<int64_t> max_problem_size = std::nullopt) {
destroy();
this->world_size = world_size;
this->rank = rank;
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
this->kMaxProblemSize = max_problem_size.value();
}
// Allocate buffer size for worst case: F16 2-stage buffer.
uint32_t flags_buffer_size =
2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
data_offset = flags_buffer_size;
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
hipDeviceMallocUncached));
// Clear the flags buffer.
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
// Device-side list of IPC buffers.
buffer_list.resize(world_size);
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
// Create IPC handles for rank's communication buffer.
all_buffer_ipc_handles.resize(world_size);
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
initialized = true;
}
int get_world_size() { return world_size; }
int get_rank() { return rank; }
bool status() { return initialized; }
hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; }
void destroy() {
if (initialized) {
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
}
}
HIP_CHECK(hipFree(dbuffer));
HIP_CHECK(hipFree(dbuffer_list));
initialized = false;
}
}
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
for (int i = 0; i < world_size; i++) {
all_buffer_ipc_handles[i] = ipc_handles[i];
}
// Open device memory access to the IPC communication buffers.
// Note: For our own rank, we do not need to open a handle.
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i],
all_buffer_ipc_handles[i],
hipIpcMemLazyEnablePeerAccess));
} else {
buffer_list[i] = dbuffer;
}
}
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(),
world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
}
template <typename T, bool cast_bf2half>
void allreduce(T const* A, T* B, uint32_t N, int quant_level,
hipStream_t stream) {
if (world_size != 2 && world_size != 4 && world_size != 8) {
throw std::runtime_error("All Reduce not supported for world_size = " +
std::to_string(world_size));
}
// Configuration.
uint32_t msg_size = N * sizeof(T);
uint32_t num_blocks = divceil(msg_size, kTileSize);
uint32_t grid = min(kMaxNumBlocks, num_blocks);
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
switch (quant_level_) {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
break;
case QuickReduceQuantLevel::INT6:
TWOSHOT_DISPATCH(CodecQ6)
break;
case QuickReduceQuantLevel::INT4:
TWOSHOT_DISPATCH(CodecQ4)
break;
default:
TWOSHOT_DISPATCH(CodecFP)
break;
}
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
flag_color += divceil(N, grid);
}
};
} // namespace quickreduce

View File

@ -0,0 +1,698 @@
#pragma once
#include <hip/hip_runtime.h>
#include "base.h"
namespace quickreduce {
struct CodecBase {
const int thread;
const int rank;
const int group_leader;
__quickreduce_device_inline__ CodecBase(int thread, int rank)
: thread(thread),
rank(rank),
group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
set_fp16_ovfl(true);
}
};
// Default full precision codec.
template <typename T, int world_size>
struct CodecFP : public CodecBase {
static constexpr int kWorldSize = world_size;
static constexpr int kRankAtoms = kAtoms / kWorldSize;
// Codec tile size process by this workgroup.
// Each thread processes atoms of f16x8_t (16B).
static constexpr int kRankTransmittedTileSize =
kBlockSize * kRankAtoms * sizeof(int32x4_t);
static_assert(kRankTransmittedTileSize % 16 == 0,
"kRankTransmittedTileSize must be 16B aligned.");
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize =
kRankTransmittedTileSize * kWorldSize;
__quickreduce_device_inline__ CodecFP(int thread, int rank)
: CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
const int32x4_t* __restrict__ data) {
for (int i = 0; i < kRankAtoms; i++) {
__builtin_nontemporal_store(data[i], send_buffer + thread);
send_buffer += kAtomStride;
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
int32x4_t* __restrict__ data) {
for (int i = 0; i < kRankAtoms; i++) {
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
*recv_buffer += kAtomStride;
}
}
};
// Int4 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ4 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of fp16x8_t (16B),
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 1152;
static constexpr int kRankTileScaleOffset = 1024;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0,
"kRankTransmittedTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride =
kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize =
kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/8.0h, -1/8.0h}, f16x2_t
static constexpr int kScaleFactor =
std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
// {1e-7, 1e-7}, f16x2_t
static constexpr int kScaleEpsilon =
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-8, -8}, f16x2_t
static constexpr int kRangeMin =
std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
// {+7, +7}, f16x2_t
static constexpr int kRangeMax =
std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
// {+8, +8}, int16x2_t
static constexpr int kRangeBias = 0x00080008;
__quickreduce_device_inline__ CodecQ4(int thread, int rank)
: CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
const int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q4 into int32_t
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr =
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
__builtin_nontemporal_store(qw, qw_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
int32_t qw = __builtin_nontemporal_load(qw_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q4 into f16x8_t
int32x4_t w;
{
static constexpr uint kMask000F = 0x000F000F;
static constexpr uint kHalf2_1024 =
0x64006400; // {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1032 =
0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
for (int i = 0; i < 4; i++) {
if constexpr (std::is_same<T, half>::value) {
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
w[i] = packed_add<half>(q4, kHalf2_1032);
} else {
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
data[k] = w;
}
}
};
// Int6 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ6 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of fp16x8_t (16B),
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 1664;
static constexpr int kRankTileQ2Offset = 1024;
static constexpr int kRankTileScaleOffset = 1536;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0,
"kRankTransmittedTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride =
kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize =
kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/32.0h, -1/32.0h}, fp16x2_t
static constexpr int kScaleFactor =
std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
// {1e-7, 1e-7}, fp16x2_t
static constexpr int kScaleEpsilon =
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-32, -32}, fp16x2_t
static constexpr int kRangeMin =
std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
// {+31, +31}, fp16x2_t
static constexpr int kRangeMax =
std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
// {+32, +32}, int16x2_t
static constexpr int kRangeBias = 0x00200020;
__quickreduce_device_inline__ CodecQ6(int thread, int rank)
: CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
const int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q6 into int32_t + int16_t
uint32_t q4w;
uint16_t q2w = 0;
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
{
int16_t* tw = reinterpret_cast<int16_t*>(&q);
#pragma unroll
for (int i = 0; i < 8; i++) {
q2w |= (tw[i] >> 4) << (i * 2);
}
}
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr =
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
uint16_t* q2w_ptr =
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
__builtin_nontemporal_store(q4w, q4w_ptr);
__builtin_nontemporal_store(q2w, q2w_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
uint16_t* q2w_ptr =
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q6 into fp16x8_t
int32x4_t w;
{
static uint constexpr kMask000F = 0x000F000F;
static uint constexpr kHalf2_1024 =
0x64006400; // {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1056 =
0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
#pragma unroll
for (int i = 0; i < 4; i++) {
int32_t q4 = q4w & kMask000F;
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
q4w >>= 4;
q2w >>= 4;
if constexpr (std::is_same<T, half>::value) {
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
asm volatile("v_pk_add_f16 %0, %1, %2"
: "=v"(w[i])
: "v"(q6), "v"(kHalf2_1056));
} else {
int32_t int16_2 = q4 | (q2 << 4);
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
// That's pretty much it...
data[k] = w;
}
}
};
// Int8 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ8 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of f16x8_t (16B),
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 2176;
static constexpr int kRankTileScaleOffset = 2048;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0,
"kRankTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride =
kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize =
kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/128.0h, -1/128.0h}, f16x2_t
static constexpr int kScaleFactor =
std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
// {1e-7, 1e-7}, f16x2_t
static constexpr int kScaleEpsilon =
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-128, -128}, f16x2_t
static constexpr int kRangeMin =
std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
// {+127, +127}, f16x2_t
static constexpr int kRangeMax =
std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
// {+128, +128}, int16x2_t
static constexpr int kRangeBias = 0x00800080;
__quickreduce_device_inline__ CodecQ8(int thread, int rank)
: CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
int32x4_t const* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q8 into int32x2_t
int32x2_t qw;
qw[0] = q[0] | (q[1] << 8);
qw[1] = q[2] | (q[3] << 8);
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr =
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
__builtin_nontemporal_store(qw, qw_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
(thread / 8);
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q8 into fp16x8_t
int32x4_t w;
{
static uint constexpr kMask00FF = 0x00FF00FF;
// {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1024 = 0x64006400;
// {-1152.0, -1152.0}, fp16x2_t
static uint constexpr kHalf2_1152 = 0xE480E480;
#pragma unroll
for (int i = 0; i < 4; i++) {
if constexpr (std::is_same<T, half>::value) {
int32_t q8 =
((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
w[i] = packed_add<half>(q8, kHalf2_1152);
} else {
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
data[k] = w;
}
}
};
// Twoshot All Reduce
template <typename T, class Codec, bool cast_bf2half>
struct AllReduceTwoshot {
static_assert(sizeof(T) == 2);
static constexpr int kWorldSize = Codec::kWorldSize;
__device__ static void run(
T const* __restrict__ input, T* __restrict__ output,
uint32_t const N, // number of elements
int const block, // block index
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
for (int i = 0; i < kAtoms; i++) {
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
src_offset += kAtomStride * sizeof(int32x4_t);
if constexpr (cast_bf2half) {
const nv_bfloat162* bf_buf =
reinterpret_cast<const nv_bfloat162*>(&tA[i]);
half2 half_buf[4];
#pragma unroll
for (int j = 0; j < 4; ++j) {
float2 f = __bfloat1622float2(bf_buf[j]);
half_buf[j] = __float22half2_rn(f);
}
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
}
}
// --------------------------------------------------------
// Phase-1A: Write segment data into the communication buffer of the target
// rank responsible for this segment.
uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset =
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset =
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset +
rank * Codec::kRankTransmittedTileSize);
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
}
__syncthreads();
if (thread < kWorldSize) {
int r = thread;
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
set_sync_flag(flag_ptr, flag_color);
}
// --------------------------------------------------------
// Phase-1B: Reduce the segment data from the communication buffers.
int32x4_t tR[Codec::kRankAtoms] = {};
{
// Read the data from the communication buffer.
int32x4_t* recv_buffer =
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
uint32_t* flag_ptr =
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
for (int r = 0; r < kWorldSize; r++) {
// Wait for the flags to be set.
if (thread == 0) {
wait_sync_flag(&flag_ptr[r], flag_color);
}
__syncthreads();
// note: we reuse tA as temp buffer here
codec.recv(&recv_buffer, tA);
for (int i = 0; i < Codec::kRankAtoms; i++) {
packed_assign_add<T>(&tR[i], &tA[i]);
}
}
}
// Phase-2: Write the reduced segment to every other rank
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset +
rank * Codec::kRankTransmittedTileSize);
codec.send(send_buffer, tR);
}
__syncthreads();
if (thread < kWorldSize) {
int r = thread;
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
set_sync_flag(flag_ptr, flag_color);
}
// Phase-2: Read the gather segments from the rank's communication buffer.
{
// Read the data from the communication buffer.
int32x4_t* recv_buffer =
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
uint32_t* flag_ptr =
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
for (int r = 0; r < kWorldSize; r++) {
// Wait for the flags to be set.
if (thread == 0) {
wait_sync_flag(&flag_ptr[r], flag_color);
}
__syncthreads();
// Gather all reduced and final rank segments into tA.
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
}
}
// --------------------------------------------------------
// Write the result to output.
BufferResource dst_buffer(output, N * sizeof(T));
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
for (int i = 0; i < kAtoms; i++) {
if constexpr (cast_bf2half) {
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
nv_bfloat162 bf16_buf[4];
#pragma unroll
for (int j = 0; j < 4; ++j) {
float2 f = __half22float2(half_buf[j]);
bf16_buf[j] = __float22bfloat162_rn(f);
}
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf),
dst_buffer.descriptor, dst_offset, 0, 0);
} else {
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
}
dst_offset += kAtomStride * sizeof(int32x4_t);
}
}
};
} // namespace quickreduce

View File

@ -725,6 +725,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
custom_ar.def("free_shared_buffer", &free_shared_buffer);
#ifdef USE_ROCM
// Quick Reduce all-reduce kernels
custom_ar.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
custom_ar.def("init_custom_qr", &init_custom_qr);
custom_ar.def("qr_destroy", &qr_destroy);
custom_ar.def("qr_get_handle", &qr_get_handle);
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import ray
import torch
import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group, graph_capture)
from vllm.platforms import current_platform
from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment, multi_process_parallel)
torch.manual_seed(42)
random.seed(44)
# Size over 8MB is sufficient for custom quick allreduce.
test_sizes = [
random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
]
for i, v in enumerate(test_sizes):
test_sizes[i] -= v % 8
@ray.remote(num_gpus=1, max_calls=1)
def graph_quickreduce(
monkeypatch: pytest.MonkeyPatch,
tp_size,
pp_size,
rank,
distributed_init_port,
):
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tensor_model_parallel_group().device_group
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
del data
# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication = rank // tp_size + 1
for sz in test_sizes:
for dtype in [torch.float16, torch.bfloat16]:
with graph_capture(device=device) as graph_capture_context:
inp1 = torch.randint(1,
23, (sz, ),
dtype=dtype,
device=torch.cuda.current_device())
inp2 = torch.randint(-23,
1, (sz, ),
dtype=dtype,
device=torch.cuda.current_device())
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph,
stream=graph_capture_context.stream):
for _ in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1)
dist.all_reduce(inp1, group=group)
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1)
torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1)
@ray.remote(num_gpus=1, max_calls=1)
def eager_quickreduce(
monkeypatch: pytest.MonkeyPatch,
tp_size,
pp_size,
rank,
distributed_init_port,
):
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
# Size over 8MB is sufficient for custom quick allreduce.
sz = 16 * 1024 * 1024
fa = get_tp_group().device_communicator.qr_comm
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
dtype=torch.float16,
device=device)
out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)],
dtype=torch.bfloat16,
device=device)
out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test quick allreduce for rocm")
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
pipeline_parallel_size, test_target,
quant_mode):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
test_target)

View File

@ -1748,6 +1748,38 @@ def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
# quick all reduce
def init_custom_qr(rank: int,
world_size: int,
qr_max_size: Optional[int] = None) -> int:
return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size)
def qr_destroy(fa: int) -> None:
torch.ops._C_custom_ar.qr_destroy(fa)
def qr_all_reduce(fa: int,
inp: torch.Tensor,
out: torch.Tensor,
quant_level: int,
cast_bf2half: bool = False) -> None:
torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level,
cast_bf2half)
def qr_get_handle(fa: int) -> torch.Tensor:
return torch.ops._C_custom_ar.qr_get_handle(fa)
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
return torch.ops._C_custom_ar.qr_open_handles(fa, handles)
def qr_max_size() -> int:
return torch.ops._C_custom_ar.qr_max_size()
def get_flash_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,

View File

@ -8,6 +8,7 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
@ -41,6 +42,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@ -50,6 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@ -57,6 +61,14 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=self.device,
)
if current_platform.is_rocm():
# Initialize a custom quick all-reduce implementation for AMD.
# Quick reduce is designed as a complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
# If it's a rocm, 'use_custom_allreduce==True' means it must
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@ -79,8 +91,14 @@ class CudaCommunicator(DeviceCommunicatorBase):
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if qr_comm is not None and not qr_comm.disabled and \
qr_comm.should_quick_allreduce(input_):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):

View File

@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless
logger = init_logger(__name__)
try:
ops.qr_max_size()
quick_ar = True
except Exception:
# For CPUs and CUDA
quick_ar = False
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
class QuickReduceRegime(Enum):
FP = 0
INT8 = 1
INT6 = 2
INT4 = 3
NONE = 4
MB = 1024 * 1024
class QuickAllReduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 8]
_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# The following data is based on kernel tests.
# In this order [FP, INT8, INT6, INT4].
_QR_MIN_SIZE = {
(torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
(torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB],
(torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
(torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
(torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB],
(torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
}
def __init__(self, group: ProcessGroup,
device: Union[int, str, torch.device]) -> None:
"""
Custom allreduce provides non-destructive acceleration and is
available for CUDA and ROCm MI300 series.
Custom quick allreduce leverages quantization for further
acceleration on ROCm. It currently supports Q8, Q6, and Q4
quantization formats and FP(float16, bfloat16).
Quick allreduce is designed as a complement to custom allreduce.
Its initialization requires even stricter conditions.
Only the ROCm MI300 series is supported for quick allreduce at
this time.
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self.disabled = True
if not self._rocm_arch_available():
logger.debug(
"Custom quick allreduce is only supported on ROCm MI300 series."
)
return
if not quick_ar:
# disable because of missing quick reduce library
# e.g. in a cuda environment
logger.info("Custom quick allreduce is disabled because "
"of missing custom quick allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"Custom quick allreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom quick allreduce for
# multi-node case.
logger.warning("Custom quick allreduce is disabled because this "
"process group spans across nodes.")
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
self.rank = rank
self.world_size = world_size
if world_size == 1:
# No need to initialize QuickReduce for single GPU case.
return
if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom quick allreduce is disabled due to an "
"unsupported world size: %d. Supported world sizes: %s.",
world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(self.world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom quick allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda_alike()
self.fully_connected = current_platform.is_fully_connected(
physical_device_ids)
if self.world_size > 2 and not self.fully_connected:
logger.debug(
"Custom quick allreduce is disabled because it's not supported "
"on more than two PCIe-only GPUs. ")
return
self.init_quick_all_reduce()
def init_quick_all_reduce(self):
# On RocM, bfloat16 kernels are slower than fp16
# due to slower match operations
# If environment variable is set to 1, we convert input to fp16
self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
if regime_str not in QuickReduceRegime.__members__:
logger.warning(
"Custom quick allreduce:",
f"Invalid quantization level: {regime_str}. "
"Supported levels: "
f"{list(QuickReduceRegime.__members__.keys())}")
return
if regime_str == "NONE":
logger.debug("Custom quick allreduce is disabled based "
"on env variable "
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'")
return
self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config()
if vllm_config is not None and \
hasattr(vllm_config, "model_config") and \
hasattr(vllm_config.model_config, "dtype"):
dtype = vllm_config.model_config.dtype
if dtype not in [torch.float16, torch.bfloat16]:
logger.debug(
"Custom quick allreduce disabled: only supports "
"float16 and float16, but get %s.", dtype)
return
if dtype == torch.bfloat16 and self.use_fp16_kernels:
logger.info(
"Custom quick allreduce: BF16 inputs will be converted "
"to FP16 to improve performance. set "
"envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 "
"to turn off.")
# VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB
qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
if qr_max_size is not None:
if qr_max_size < 1:
logger.info(
"You should not set a max_size smaller than 1MB, which can "
"lead to error or degradation to custom allreduce or rccl."
)
qr_max_size = qr_max_size * MB
self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size)
self.qr_max_size = qr_max_size if qr_max_size is not None \
else ops.qr_max_size()
self.create_shared_buffer()
self.disabled = False
def _rocm_arch_available(self):
if not current_platform.is_rocm():
return False
try:
props = torch.cuda.get_device_properties(0)
gcn_arch = getattr(props, "gcnArchName", "")
supported_archs = ['gfx94', 'gfx95']
return any(gfx in gcn_arch for gfx in supported_archs)
except Exception as e:
logger.warning("Failed to determine ROCm for quick allreduce: %s",
e)
return False
def create_shared_buffer(self):
"""
Creates a shared buffer for quickreduce.
Has to be called after init_custom_qr
"""
handle = ops.qr_get_handle(self._ptr)
world_size = dist.get_world_size(group=self.group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=self.group)
ops.qr_open_handles(self._ptr, handles)
def should_quick_allreduce(self, inp: torch.Tensor):
"""
Check if quickreduce is available
"""
if self.disabled:
return False
if inp.dtype not in self._SUPPORTED_DTYPES:
return False
inp_size = inp.numel() * inp.element_size()
# custom quick allreduce requires input byte size to be
# multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
dtype = inp.dtype
if self.use_fp16_kernels:
dtype = torch.float16
return inp_size <= self.qr_max_size and \
inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
[self.qr_quant_level.value]
def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None):
"""Performs an out-of-place custom quick all reduce."""
# quick allreduce doesn't require a separate graph mode,
# as QR uses static IPC buffer.
if out is None:
out = torch.empty_like(inp)
ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value,
self.use_fp16_kernels)
return out
def close(self):
if not self.disabled and getattr(self, "_ptr", None):
if ops is not None:
ops.qr_destroy(self._ptr)
self._ptr = 0
self.disabled = True
def __del__(self):
self.close()

View File

@ -135,6 +135,9 @@ if TYPE_CHECKING:
VLLM_KV_CACHE_LAYOUT: Optional[str] = None
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
def get_default_cache_root():
@ -690,6 +693,31 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),
# Custom quick allreduce kernel for MI3* cards
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
# Recommended for large models to get allreduce
"VLLM_ROCM_QUICK_REDUCE_QUANTIZATION":
lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(),
# Custom quick allreduce kernel for MI3* cards
# Due to the lack of the bfloat16 asm instruction, bfloat16
# kernels are slower than fp16,
# If environment variable is set to 1, the input is converted to fp16
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16":
lambda:
(os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in
("true", "1")),
# Custom quick allreduce kernel for MI3* cards.
# Controls the maximum allowed number of data bytes(MB) for custom quick
# allreduce communication.
# Default: 2048 MB.
# Data exceeding this size will use either custom allreduce or RCCL
# communication.
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB":
lambda: maybe_convert_int(
os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)),
# If set, when running in Quark emulation mode, do not dequantize the
# weights at load time. Instead, dequantize weights on-the-fly during
# kernel execution.