ggml : add ggml_soft_max_ext (#4256)

* metal : implement soft_max_ext

* cuda : implement soft_max_ext

* ggml : implement soft_max_ext (CPU)

* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32

* ggml : update soft max cpu

* cuda : do warp-based block reduce

* cuda : increase max block size to 1024

* cuda : fix warp reduction initialization of shared mem

* metal : warp-based reduction for soft max kernel

* metal : warp-based reduce for rms_norm

* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug
This commit is contained in:
Georgi Gerganov 2023-12-01 10:51:24 +02:00 committed by GitHub
parent 1d144112c0
commit ef47ec18da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 311 additions and 196 deletions

View file

@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
}
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}
// the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y;
const int tid = threadIdx.y;
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
const int block_size = blockDim.x;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
float max_val = -INFINITY;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = max(max_val, x[i]);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}
// find the max value in the block
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = max_val;
}
__syncthreads();
max_val = buf[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.f;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const float val = expf(x[i] - max_val);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[i] = val;
dst[ix] = val;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = 0.f;
}
__syncthreads();
if (lane_id == 0) {
buf[warp_id] = tmp;
}
__syncthreads();
tmp = buf[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float inv_tmp = 1.f / tmp;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1);
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_cuda_op_scale(