correct tensors thru RoPE

This commit is contained in:
Phillip Kravtsov 2023-09-26 00:07:19 -07:00
parent 4bcf412d86
commit c9e1446f52
2 changed files with 5 additions and 5 deletions

6
ggml.c
View file

@ -12729,8 +12729,8 @@ static void ggml_compute_forward_rope_f32(
if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) { if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
GGML_PRINT("\nValues at RoPE time for %s\n", src0->name); GGML_PRINT("\nValues at RoPE time for %s\n", src0->name);
ggml_print_tensor(src0); ggml_print_tensor(src0);
int starts[] = {0, 0, 1, 0}; int starts[] = {0, 1, 0, 0};
ggml_print_tensor_values(src0, starts, 1, 10); ggml_print_tensor_values(src0, starts, 0, 10);
} }
float freq_base; float freq_base;
@ -12863,7 +12863,7 @@ static void ggml_compute_forward_rope_f32(
if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) { if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
GGML_PRINT("\n dest at RoPE time for %s\n", src0->name); GGML_PRINT("\n dest at RoPE time for %s\n", src0->name);
// print shape and strides // print shape and strides
int starts[4] = {0,0,0,0}; int starts[4] = {0,0,1,0};
ggml_print_tensor(dst); ggml_print_tensor(dst);
ggml_print_tensor_values(dst, starts, 0, 10); ggml_print_tensor_values(dst, starts, 0, 10);
} }

View file

@ -3945,13 +3945,13 @@ static struct ggml_cgraph * llm_build_adept(
struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0, struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_rope_custom_inplace( ggml_rope_custom_inplace(
ctx0, qrot, n_past, n_rot, 0, 0, freq_base, freq_scale ctx0, qrot, n_past, n_rot, 2, 0, freq_base, freq_scale
), ),
2, 1, 0, 3 2, 1, 0, 3
)); ));
struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0, struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_rope_custom_inplace( ggml_rope_custom_inplace(
ctx0, krot, n_past, n_rot, 0, 0, freq_base, freq_scale ctx0, krot, n_past, n_rot, 2, 0, freq_base, freq_scale
), ),
2, 1, 0, 3 2, 1, 0, 3
)); ));