YaRN : fix missing parameter in CUDA impl
This commit is contained in:
parent
cf731d5648
commit
dcb058ce5d
1 changed files with 8 additions and 8 deletions
14
ggml-cuda.cu
14
ggml-cuda.cu
|
@ -4073,8 +4073,8 @@ static __device__ void rope_yarn(
|
|||
|
||||
// rope == RoPE == rotary positional embedding
|
||||
static __global__ void rope_f32(
|
||||
const float * x, float * dst, const int ncols, const float freq_scale, const float ext_factor,
|
||||
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims
|
||||
float * x, float * dst, int ncols, float freq_scale, float ext_factor, float attn_factor, float theta_scale,
|
||||
float p0, int p_delta_rows, rope_corr_dims corr_dims
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
|
@ -5001,15 +5001,15 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
|
|||
}
|
||||
|
||||
static void rope_f32_cuda(
|
||||
const float * x, float * dst, const int ncols, const int nrows, const float freq_scale, const float ext_factor,
|
||||
const float theta_scale, const float p0, const int p_delta_rows, const rope_corr_dims corr_dims, cudaStream_t stream
|
||||
float * x, float * dst, int ncols, int nrows, float freq_scale, float ext_factor, float attn_factor,
|
||||
float theta_scale, float p0, int p_delta_rows, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
|
||||
const dim3 block_nums(nrows, num_blocks_x, 1);
|
||||
rope_f32<<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, freq_scale, ext_factor, theta_scale, p0, p_delta_rows, corr_dims
|
||||
x, dst, ncols, freq_scale, ext_factor, attn_factor, theta_scale, p0, p_delta_rows, corr_dims
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -5785,8 +5785,8 @@ inline void ggml_cuda_op_rope(
|
|||
ggml_rope_yarn_corr_dims(n_dims, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
rope_f32_cuda(
|
||||
src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, theta_scale, p0, ne01, corr_dims,
|
||||
cudaStream_main
|
||||
src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ext_factor, attn_factor, theta_scale, p0, ne01,
|
||||
corr_dims, cudaStream_main
|
||||
);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue