YaRN : fix missing parameter in CUDA impl

This commit is contained in:
Cebtenzzre 2023-09-05 14:17:50 -04:00
parent cf731d5648
commit dcb058ce5d

View file

@ -4073,8 +4073,8 @@ static __device__ void rope_yarn(
// rope == RoPE == rotary positional embedding // rope == RoPE == rotary positional embedding
static __global__ void rope_f32( static __global__ void rope_f32(
const float * x, float * dst, const int ncols, const float freq_scale, const float ext_factor, float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims float p0, int p_delta_rows, rope_corr_dims corr_dims
) { ) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@ -4086,7 +4086,7 @@ static __global__ void rope_f32(
const int i = row*ncols + col; const int i = row*ncols + col;
const float p = p0 + row / p_delta_rows; const float p = p0 + row / p_delta_rows;
const float theta_base = p*powf(theta_scale, col/2); const float theta_base = p*powf(theta_scale, col/2);
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@ -5001,15 +5001,15 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
} }
static void rope_f32_cuda( static void rope_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const float freq_scale, const float ext_factor, float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims, cudaStream_t stream float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
) { ) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>( rope_f32<<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, freq_scale, ext_factor, theta_scale, p0, p_delta_rows, corr_dims x, dst, ncols, freq_scale, ext_factor, attn_factor, theta_scale, p0, p_delta_rows, corr_dims
); );
} }
@ -5785,8 +5785,8 @@ inline void ggml_cuda_op_rope(
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v); ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
rope_f32_cuda( rope_f32_cuda(
src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, theta_scale, p0, ne01, corr_dims, src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01,
cudaStream_main corr_dims, cudaStream_main
); );
} }