llama : switch to floating-point token positions

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-23 12:18:30 +02:00
parent 15499eb942
commit fc775366f1
No known key found for this signature in database
GPG key ID: BF970631944C16B7
14 changed files with 68 additions and 61 deletions

View file

@@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(
typedef void (rope_t)(
device const void * src0,
- device const int32_t * src1,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -1709,7 +1709,7 @@ typedef void (rope_t)(
template<typename T>
kernel void kernel_rope(
device const void * src0,
- device const int32_t * src1,
+ device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@@ -1749,11 +1749,11 @@ kernel void kernel_rope(
float corr_dims[2];
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
- device const int32_t * pos = src1;
+ device const float * pos = src1;
- const int64_t p = pos[i2];
+ const float p = pos[i2];
- const float theta_0 = (float)p;
+ const float theta_0 = p;
const float inv_ndims = -1.f/n_dims;
if (!is_neox) {