YaRN : fix missing parameter in CUDA impl
This commit is contained in:
parent
cf731d5648
commit
dcb058ce5d
1 changed files with 8 additions and 8 deletions
16
ggml-cuda.cu
16
ggml-cuda.cu
|
@ -4073,8 +4073,8 @@ static __device__ void rope_yarn(
|
||||||
|
|
||||||
// rope == RoPE == rotary positional embedding
|
// rope == RoPE == rotary positional embedding
|
||||||
static __global__ void rope_f32(
|
static __global__ void rope_f32(
|
||||||
const float * x, float * dst, const int ncols, const float freq_scale, const float ext_factor,
|
float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
||||||
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims
|
float p0, int p_delta_rows, rope_corr_dims corr_dims
|
||||||
) {
|
) {
|
||||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||||
|
|
||||||
|
@ -4086,7 +4086,7 @@ static __global__ void rope_f32(
|
||||||
const int i = row*ncols + col;
|
const int i = row*ncols + col;
|
||||||
|
|
||||||
const float p = p0 + row / p_delta_rows;
|
const float p = p0 + row / p_delta_rows;
|
||||||
const float theta_base = p*powf(theta_scale, col/2);
|
const float theta_base = p*powf(theta_scale, col/2);
|
||||||
|
|
||||||
float cos_theta, sin_theta;
|
float cos_theta, sin_theta;
|
||||||
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||||
|
@ -5001,15 +5001,15 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rope_f32_cuda(
|
static void rope_f32_cuda(
|
||||||
const float * x, float * dst, const int ncols, const int nrows, const float freq_scale, const float ext_factor,
|
float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
||||||
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims, cudaStream_t stream
|
float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
|
||||||
) {
|
) {
|
||||||
GGML_ASSERT(ncols % 2 == 0);
|
GGML_ASSERT(ncols % 2 == 0);
|
||||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(
|
rope_f32<<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ncols, freq_scale, ext_factor, theta_scale, p0, p_delta_rows, corr_dims
|
x, dst, ncols, freq_scale, ext_factor, attn_factor, theta_scale, p0, p_delta_rows, corr_dims
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5785,8 +5785,8 @@ inline void ggml_cuda_op_rope(
|
||||||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
|
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||||
|
|
||||||
rope_f32_cuda(
|
rope_f32_cuda(
|
||||||
src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, theta_scale, p0, ne01, corr_dims,
|
src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01,
|
||||||
cudaStream_main
|
corr_dims, cudaStream_main
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue