cuda : add ALiBi support in ggml_soft_max_ext

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-14 18:29:24 +02:00
parent 97d6a0cc06
commit a0f8a93bf1
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 69 additions and 34 deletions

View file

@ -5957,7 +5957,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
} }
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check> template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template; const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2; const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@ -5971,6 +5971,21 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
float slope = 0.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t n_head_kv = gridDim.x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int h = rowx/nrows_y; // head index
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
}
extern __shared__ half data_soft_max_f16[]; extern __shared__ half data_soft_max_f16[];
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
// (shared memory) buffer to cache values between iterations: // (shared memory) buffer to cache values between iterations:
@ -5992,12 +6007,12 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
if (need_check && col_data + 0 >= ncols_data) { if (need_check && col_data + 0 >= ncols_data) {
val.x = -INFINITY; val.x = -INFINITY;
} else { } else {
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f); val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f) + slope*col_data;
} }
if (need_check && col_data + WARP_SIZE >= ncols_data) { if (need_check && col_data + WARP_SIZE >= ncols_data) {
val.y = -INFINITY; val.y = -INFINITY;
} else { } else {
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f); val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f) + slope*(col_data + WARP_SIZE);
} }
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) { if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
vals[col_smem] = val; vals[col_smem] = val;
@ -6087,7 +6102,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
} }
template <bool vals_smem, int ncols_template, int block_size_template> template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) { static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x; const int tid = threadIdx.x;
@ -6099,6 +6114,21 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
float slope = 0.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t n_head_kv = gridDim.x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int h = rowx/nrows_y; // head index
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
}
extern __shared__ float data_soft_max_f32[]; extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations: // shared memory buffer to cache values between iterations:
@ -6117,7 +6147,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
const int ix = rowx*ncols + col; const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col; const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (y ? y[iy] : 0.0f); const float val = x[ix]*scale + (y ? y[iy] : 0.0f) + slope*col;
vals[col] = val; vals[col] = val;
max_val = max(max_val, val); max_val = max(max_val, val);
} }
@ -7589,7 +7620,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past); diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
} }
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE; int nth = WARP_SIZE;
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1); const dim3 block_dims(nth, 1, 1);
@ -7599,40 +7630,40 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
if (shmem <= g_device_caps[g_main_device].smpb) { if (shmem <= g_device_caps[g_main_device].smpb) {
switch (ncols_x) { switch (ncols_x) {
case 32: case 32:
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 64: case 64:
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 128: case 128:
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 256: case 256:
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 512: case 512:
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 1024: case 1024:
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 2048: case 2048:
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 4096: case 4096:
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
default: default:
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
} }
} else { } else {
const size_t shmem_low = WARP_SIZE*sizeof(half); const size_t shmem_low = WARP_SIZE*sizeof(half);
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
} }
} }
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) { static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE; int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1); const dim3 block_dims(nth, 1, 1);
@ -7642,36 +7673,36 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
if (shmem < g_device_caps[g_main_device].smpb) { if (shmem < g_device_caps[g_main_device].smpb) {
switch (ncols_x) { switch (ncols_x) {
case 32: case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 64: case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 128: case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 256: case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 512: case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 1024: case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 2048: case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
case 4096: case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
default: default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
break; break;
} }
} else { } else {
const size_t shmem_low = WARP_SIZE*sizeof(float); const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale); soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale, max_bias);
} }
} }
@ -9091,10 +9122,13 @@ static void ggml_cuda_op_soft_max(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1; const int64_t nrows_y = src0->ne[1];
float scale = 1.0f; float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float)); float max_bias = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
#ifdef GGML_CUDA_F16 #ifdef GGML_CUDA_F16
@ -9107,9 +9141,9 @@ static void ggml_cuda_op_soft_max(
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
if (use_f16_soft_max) { if (use_f16_soft_max) {
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
} else { } else {
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream); soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
} }
(void) dst; (void) dst;

View file

@ -373,6 +373,7 @@ kernel void kernel_soft_max(
float slope = 0.0f; float slope = 0.0f;
// ALiBi
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
const uint32_t n_head_kv = ne02; const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));

2
ggml.c
View file

@ -11516,7 +11516,7 @@ static void ggml_compute_forward_soft_max_f32(
ggml_vec_acc_f32(nc, wp, mp); ggml_vec_acc_f32(nc, wp, mp);
} }
// alibi bias // ALiBi bias
if (max_bias > 0.0f) { if (max_bias > 0.0f) {
const uint32_t h = (i1/ne01)%ne02; // head const uint32_t h = (i1/ne01)%ne02; // head
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);

View file

@ -2093,7 +2093,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) { for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng); int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng); int64_t ne1 = dist_ne1(rng);
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, 4.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
} }
exponent <<= 1; exponent <<= 1;