From aa18b939802a1be7f65f6d77ccc032434e5b5e01 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 19 Sep 2023 08:51:05 +0200 Subject: [PATCH] simpler rope implementation --- ggml-cuda.cu | 69 ++++++++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 268b3666a..14b1ecf7d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4356,15 +4356,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding -static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, int mode, float freq_scale) { - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - int p = pos[i]; - p0[i] = (((mode & 1) == 0 ? p : 0)) * freq_scale; - } -} -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, +static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4376,7 +4369,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int i = row*ncols + col; const int i2 = row/p_delta_rows; - const float theta = p0[i2]*powf(theta_scale, col/2); + const int p = pos != nullptr ? pos[i2] : 0; + const float p0 = p * freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4387,8 +4382,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0, - const int p_delta_rows, const float theta_scale) { +static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4399,7 +4394,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - const float theta = p0[i2]*powf(theta_scale, col/2); + const int p = pos != nullptr ? pos[i2] : 0; + const float p0 = p * freq_scale; + const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4410,8 +4407,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -4425,9 +4422,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const float col_theta_scale = powf(theta_scale, col); // FIXME: this is likely wrong - const float p = p0[i2]; + const int p = pos != nullptr ? pos[i2] : 0; - const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; + const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4437,7 +4434,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale; + const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; const float sin_block_theta = sinf(block_theta); const float cos_block_theta = cosf(block_theta); @@ -5374,31 +5371,31 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, +static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); + rope_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, +static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_neox_f32<<>>(x, dst, ncols, p0, p_delta_rows, theta_scale); + rope_neox_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, - const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { GGML_ASSERT(ncols % 4 == 0); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx); + rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); } static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, @@ -6105,32 +6102,30 @@ inline void ggml_cuda_op_rope( int id; CUDA_CHECK(cudaGetDevice(&id)); - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - if (!src1_extra->copied) { - CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); - src1_extra->copied = true; + int * pos = nullptr; + if ((mode & 1) == 0) { + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + pos = (int *) src1_extra->data_device[id]; + if (!src1_extra->copied) { + CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); + src1_extra->copied = true; + } } - size_t p0d_as; - float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); - compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale); - const bool is_neox = mode & 2; const bool is_glm = mode & 4; // compute if (is_glm) { GGML_ASSERT(false); - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); } - ggml_cuda_pool_free(p0d, p0d_as); - (void) src1; (void) dst; (void) src1_dd;