llama : switch to floating-point token positions
ggml-ci
This commit is contained in:
parent
15499eb942
commit
fc775366f1
14 changed files with 68 additions and 61 deletions
|
@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(
|
|||
|
||||
typedef void (rope_t)(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -1709,7 +1709,7 @@ typedef void (rope_t)(
|
|||
template<typename T>
|
||||
kernel void kernel_rope(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -1749,11 +1749,11 @@ kernel void kernel_rope(
|
|||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = src1;
|
||||
device const float * pos = src1;
|
||||
|
||||
const int64_t p = pos[i2];
|
||||
const float p = pos[i2];
|
||||
|
||||
const float theta_0 = (float)p;
|
||||
const float theta_0 = p;
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue