fix rope
This commit is contained in:
parent
fb92acdd6b
commit
cbe2bac281
1 changed files with 15 additions and 13 deletions
28
ggml-cuda.cu
28
ggml-cuda.cu
|
@ -4365,7 +4365,7 @@ static __global__ void compute_rope_p0(const int32_t * pos, float * p0, int n, i
|
|||
}
|
||||
|
||||
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale) {
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
|
@ -4376,7 +4376,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
|
|||
const int i = row*ncols + col;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
|
||||
const float theta = p0[i2]*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
|
@ -4388,7 +4388,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
|
|||
}
|
||||
|
||||
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale) {
|
||||
const int p_delta_rows, const float theta_scale) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
if (col >= ncols) {
|
||||
|
@ -4399,7 +4399,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
|
|||
const int i = row*ncols + col/2;
|
||||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float theta = (p0[i2] + p_delta*i2)*powf(theta_scale, col/2);
|
||||
const float theta = p0[i2]*powf(theta_scale, col/2);
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
|
@ -4424,7 +4424,8 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
|||
const int i2 = row/p_delta_rows;
|
||||
|
||||
const float col_theta_scale = powf(theta_scale, col);
|
||||
const float p = p0[i2] + p_delta*i2;
|
||||
// FIXME: this is likely wrong
|
||||
const float p = p0[i2];
|
||||
|
||||
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
|
||||
const float sin_theta = sinf(theta);
|
||||
|
@ -5374,21 +5375,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
|||
}
|
||||
|
||||
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
|
||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta_rows, theta_scale);
|
||||
}
|
||||
|
||||
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
|
||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta_rows, theta_scale);
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float * p0,
|
||||
|
@ -6095,7 +6096,7 @@ inline void ggml_cuda_op_rope(
|
|||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
//const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
// const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(src1->ne[0] == ne2);
|
||||
|
@ -6110,7 +6111,7 @@ inline void ggml_cuda_op_rope(
|
|||
src1_extra->copied = true;
|
||||
}
|
||||
|
||||
size_t p0d_as = 0;
|
||||
size_t p0d_as;
|
||||
float * p0d = (float *) ggml_cuda_pool_malloc(ne2 * sizeof(float), &p0d_as);
|
||||
compute_rope_p0<<<(ne2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE, CUDA_ROPE_BLOCK_SIZE, 0, main_stream>>>((int32_t*)src1_extra->data_device[id], p0d, ne2, mode, freq_scale);
|
||||
|
||||
|
@ -6119,12 +6120,13 @@ inline void ggml_cuda_op_rope(
|
|||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
GGML_ASSERT(false);
|
||||
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, n_ctx, main_stream);
|
||||
} else if (is_neox) {
|
||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream);
|
||||
} else {
|
||||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, freq_scale, ne01, theta_scale, main_stream);
|
||||
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0d, ne01, theta_scale, main_stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_free(p0d, p0d_as);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue