tests : verify that RoPE is "additive"
This commit is contained in:
parent
80291a1d02
commit
c5df72e848
3 changed files with 217 additions and 5 deletions
9
ggml.c
9
ggml.c
|
@ -6977,7 +6977,6 @@ static struct ggml_tensor * ggml_rope_impl(
|
|||
float xpos_base,
|
||||
bool xpos_down,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(n_past >= 0);
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
|
@ -12645,8 +12644,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
||||
|
@ -12674,12 +12671,14 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
||||
const bool is_skip = mode & 1;
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
const bool is_diff = mode & 8; // TODO: temporary
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
for (int64_t i2 = (is_skip ? n_past : 0); i2 < ne2; i2++) {
|
||||
const int64_t p = is_diff ? n_past : is_skip ? i2 : n_past + i2;
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (ir++ < ir0) continue;
|
||||
if (ir > ir1) break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue