modified rope for cuda
This commit is contained in:
parent
374fffb9c6
commit
0d7240b320
1 changed file with 8 additions and 2 deletions
10
ggml-cuda.cu
10
ggml-cuda.cu
|
@@ -2972,13 +2972,19 @@ inline void ggml_cuda_op_rope(
|
|||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
const int n_ctx = ((int32_t *) src1->data)[3];
|
||||
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
||||
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
const float p = p0 * freq_scale;
|
||||
|
||||
bool is_glm = mode & 4;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue