vllm/csrc/quickreduce/quick_reduce.h

196 lines
7.3 KiB
C++

#pragma once
#include <vector>
#include <hip/hip_runtime.h>
#include "quick_reduce_impl.cuh"
#define HIP_CHECK(err) \
do { \
hipError_t err_ = (err); \
if (err_ != hipSuccess) { \
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \
hipGetErrorString(err_)); \
throw std::runtime_error("HIP error"); \
} \
} while (0)
namespace quickreduce {
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color);
block += grid;
flag_color++;
}
}
#define TWOSHOT_DISPATCH(__codec) \
if (world_size == 2) { \
using LineCodec = __codec<T, 2>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \
}
enum QuickReduceQuantLevel {
F16 = 0,
INT8 = 1,
INT6 = 2,
INT4 = 3,
};
struct DeviceComms {
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
int64_t kMaxProblemSize =
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
// Max TP-8
static int constexpr kMaxWorldSize = 8;
bool initialized = false;
uint32_t flag_color = 1;
int world_size;
int rank;
uint8_t* dbuffer;
uint8_t** dbuffer_list;
hipIpcMemHandle_t buffer_ipc_handle;
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
std::vector<uint8_t*> buffer_list;
uint32_t data_offset;
DeviceComms() : initialized(false), world_size(1), rank(0) {}
~DeviceComms() { destroy(); }
void init(int world_size, int rank,
std::optional<int64_t> max_problem_size = std::nullopt) {
destroy();
this->world_size = world_size;
this->rank = rank;
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
this->kMaxProblemSize = max_problem_size.value();
}
// Allocate buffer size for worst case: F16 2-stage buffer.
uint32_t flags_buffer_size =
2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
data_offset = flags_buffer_size;
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
hipDeviceMallocUncached));
// Clear the flags buffer.
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
// Device-side list of IPC buffers.
buffer_list.resize(world_size);
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
// Create IPC handles for rank's communication buffer.
all_buffer_ipc_handles.resize(world_size);
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
initialized = true;
}
int get_world_size() { return world_size; }
int get_rank() { return rank; }
bool status() { return initialized; }
hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; }
void destroy() {
if (initialized) {
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
}
}
HIP_CHECK(hipFree(dbuffer));
HIP_CHECK(hipFree(dbuffer_list));
initialized = false;
}
}
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
for (int i = 0; i < world_size; i++) {
all_buffer_ipc_handles[i] = ipc_handles[i];
}
// Open device memory access to the IPC communication buffers.
// Note: For our own rank, we do not need to open a handle.
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i],
all_buffer_ipc_handles[i],
hipIpcMemLazyEnablePeerAccess));
} else {
buffer_list[i] = dbuffer;
}
}
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(),
world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
}
template <typename T, bool cast_bf2half>
void allreduce(T const* A, T* B, uint32_t N, int quant_level,
hipStream_t stream) {
if (world_size != 2 && world_size != 4 && world_size != 8) {
throw std::runtime_error("All Reduce not supported for world_size = " +
std::to_string(world_size));
}
// Configuration.
uint32_t msg_size = N * sizeof(T);
uint32_t num_blocks = divceil(msg_size, kTileSize);
uint32_t grid = min(kMaxNumBlocks, num_blocks);
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
switch (quant_level_) {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
break;
case QuickReduceQuantLevel::INT6:
TWOSHOT_DISPATCH(CodecQ6)
break;
case QuickReduceQuantLevel::INT4:
TWOSHOT_DISPATCH(CodecQ4)
break;
default:
TWOSHOT_DISPATCH(CodecFP)
break;
}
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
flag_color += divceil(N, grid);
}
};
} // namespace quickreduce