ggml : ggml_rope now takes a vector with positions instead of n_past

This commit is contained in:
Georgi Gerganov 2023-09-17 21:12:51 +03:00
parent 3b4bab6a38
commit 1fb033fd85
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 270 additions and 131 deletions

View file

@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
}
};
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// rope has so much parameters that we make a custom function for it
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silcence warnings
const int n_past = 0;
const int rope_mode = 0;
return ggml_rope_custom(ctx,
t, n_past, n_rot, rope_mode, n_ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
};