diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 08428ea3f..9ead57648 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5,6 +5,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -4355,7 +4356,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, } // rope == RoPE == rotary positional embedding -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4365,8 +4366,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4377,7 +4379,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4387,8 +4389,9 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col/2; + const int i2 = row/p_delta_rows; - const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); + const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2); const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -4399,7 +4402,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0, +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int half_n_dims = ncols/4; @@ -4410,9 +4413,10 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol const int row = blockDim.y*blockIdx.y + threadIdx.y; const int i = row*ncols + col; + const int i2 = row/p_delta_rows; const float col_theta_scale = powf(theta_scale, col); - const float p = p0 + p_delta*(row/p_delta_rows); + const float p = p0[i2] + p_delta*i2; const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale; const float sin_theta = sinf(theta); @@ -5361,7 +5365,7 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -5370,7 +5374,7 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i rope_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); @@ -5379,7 +5383,7 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co rope_neox_f32<<>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); } -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0, const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { GGML_ASSERT(ncols % 4 == 0); const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); @@ -6069,9 +6073,10 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; + const int64_t ne2 = dst->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -6082,21 +6087,38 @@ inline void ggml_cuda_op_rope( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + //const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + + std::vector p0s(ne2); + for (int64_t i = 0; i < ne2; ++i) { + int n_past = ((int32_t *) src1->data)[i]; + p0s[i] = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; + } + + size_t p0d_as = 0; + float * p0d; + + p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as); + CUDA_CHECK(cudaMemcpyAsync(p0d, p0s.data(), ne2 * sizeof(float), cudaMemcpyHostToDevice, main_stream)); const bool is_neox = mode & 2; const bool is_glm = mode & 4; // compute if (is_glm) { - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream); } + ggml_cuda_pool_free(p0d, p0d_as); + (void) src1; (void) dst; (void) src1_dd;