tests : verify that RoPE is "additive"

This commit is contained in:
Georgi Gerganov 2023-09-17 17:54:14 +03:00
parent 80291a1d02
commit c5df72e848
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 217 additions and 5 deletions

9
ggml.c
View file

@ -6977,7 +6977,6 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base,
bool xpos_down,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
if (a->grad) {
@ -12645,8 +12644,6 @@ static void ggml_compute_forward_rope_f32(
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@ -12674,12 +12671,14 @@ static void ggml_compute_forward_rope_f32(
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_diff = mode & 8; // TODO: temporary
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = (is_skip ? n_past : 0); i2 < ne2; i2++) {
const int64_t p = is_diff ? n_past : is_skip ? i2 : n_past + i2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;