ggml : ggml_rope now takes a vector with positions instead of n_past

This commit is contained in:
Georgi Gerganov 2023-09-17 21:12:51 +03:00
parent 3b4bab6a38
commit 1fb033fd85
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 270 additions and 131 deletions

View file

@ -1404,6 +1404,11 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i;
}
ggml_set_param(ctx0, x[0]);
const bool skip_past = (mode & 1);
@ -1415,7 +1420,7 @@ int main(int argc, const char ** argv) {
continue;
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
@ -1438,6 +1443,11 @@ int main(int argc, const char ** argv) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
for (int i = 0; i < ne2[2]; ++i) {
((int32_t *) p->data)[i] = n_past + i;
}
ggml_set_param(ctx0, x[0]);
const bool skip_past = (mode & 1);
@ -1449,7 +1459,7 @@ int main(int argc, const char ** argv) {
continue;
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);

View file

@ -144,7 +144,17 @@ int main(int /*argc*/, const char ** /*argv*/) {
const int64_t ne[4] = { 2*n_rot, 32, 73, 1 };
const int n_past_0 = 100;
const int n_past_1 = 33;
const int n_past_2 = 33;
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
for (int i = 0; i < ne[2]; ++i) {
((int32_t *) p0->data)[i] = n_past_0 + i;
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
((int32_t *) p2->data)[i] = n_past_2 + i;
}
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
@ -152,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
// 100, 101, 102, ..., 172
struct ggml_tensor * r0 = ggml_rope(ctx0, x, n_past_0, n_rot, mode, 1024);
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
// -67, -67, -67, ..., -67
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, n_past_1 - n_past_0, n_rot, mode + 8, 1024); // diff mode
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
// 33, 34, 35, ..., 105
struct ggml_tensor * r2 = ggml_rope(ctx0, x, n_past_1, n_rot, mode, 1024);
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);
ggml_cgraph * gf = ggml_new_graph(ctx0);