From c989fd06b1c035489538cfa55366f1217eebe311 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 May 2024 13:49:46 +0300 Subject: [PATCH] cuda : better rope implementation ggml-ci --- ggml-cuda/rope.cu | 214 ++++++++++++++++++------------------- tests/test-backend-ops.cpp | 1 + 2 files changed, 106 insertions(+), 109 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index 7eb39dc86..58d2b2f7d 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -1,7 +1,7 @@ #include "rope.cuh" struct rope_corr_dims { - float v[4]; + float v[4]; // TODO: is there any reson for this to be 4 instead of 2? }; static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static __device__ void rope_yarn( float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -29,27 +28,38 @@ static __device__ void rope_yarn( *sin_theta = sinf(theta) * mscale; } -// rope == RoPE == rotary positional embedding -template -static __global__ void rope( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - float ext_factor, float attn_factor, rope_corr_dims corr_dims -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_norm( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*powf(freq_base, -float(col)/ncols); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + 1]; @@ -58,23 +68,20 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int ib = col / n_dims; - const int ic = col % n_dims; - if (ib > 0) { - const int i = row*ncols + ib*n_dims + ic; + if (i0 >= n_dims) { + const int i = row*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -82,16 +89,17 @@ static __global__ void rope_neox( return; } - const int i = row*ncols + ib*n_dims + ic/2; + const int i = row*ne0 + i0/2; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f; + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor; + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + n_dims/2]; @@ -100,93 +108,81 @@ static __global__ void rope_neox( dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } - template -static void rope_cuda( - const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); +static void rope_norm_cuda( + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); - if (pos == nullptr) { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } else { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } } template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f/n_dims); - if (pos == nullptr) { - if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + if (freq_factors == nullptr) { + rope_neox<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ); - } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); - } } else { - if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + rope_neox<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors ); - } else { - rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); - } } } -static void rope_cuda_f16( - const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { +static void rope_norm_cuda_f16( + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } -static void rope_cuda_f32( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { +static void rope_norm_cuda_f32( + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f16( - const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( - const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -207,16 +203,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nr = ggml_nrows(src0); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; // RoPE alteration for extended context - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -224,19 +226,13 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const float * freq_factors = nullptr; - const int32_t * pos = nullptr; - const bool is_neox = mode & 2; - pos = (const int32_t *) src1_d; + const int32_t * pos = (const int32_t *) src1_d; - if (is_neox) { - if (src2 != nullptr) { - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox"); + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; } rope_corr_dims corr_dims; @@ -246,12 +242,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (is_neox) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream ); } else { @@ -259,14 +255,14 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } else { if (src0->type == GGML_TYPE_F32) { - rope_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f32( + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { - rope_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 58a3bd6a4..481c70504 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) } } + all = false; } }