modified rope for cuda

This commit is contained in:
Concedo 2023-07-19 14:16:27 +08:00
parent 374fffb9c6
commit 0d7240b320

View file

@ -2972,13 +2972,19 @@ inline void ggml_cuda_op_rope(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low; const int64_t i01_diff = i01_high - i01_low;
float freq_base;
float freq_scale;
const int n_past = ((int32_t *) src1->data)[0]; const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1]; const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2]; const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3]; const int n_ctx = ((int32_t *) src1->data)[3];
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
const float theta_scale = powf(10000.0, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float p = ((mode & 1) == 0 ? n_past + i02 : i02); const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
const float p = p0 * freq_scale;
bool is_glm = mode & 4; bool is_glm = mode & 4;