ggml : ggml_rope now takes a vector with positions instead of n_past
This commit is contained in:
parent
3b4bab6a38
commit
1fb033fd85
9 changed files with 270 additions and 131 deletions
|
@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
|
|||
}
|
||||
};
|
||||
|
||||
// KQ_pos - contains the positions
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// rope has so much parameters that we make a custom function for it
|
||||
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
|
||||
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
|
||||
(struct ggml_tensor * t) -> struct ggml_tensor * {
|
||||
// not capturing these, to silcence warnings
|
||||
const int n_past = 0;
|
||||
const int rope_mode = 0;
|
||||
|
||||
return ggml_rope_custom(ctx,
|
||||
t, n_past, n_rot, rope_mode, n_ctx,
|
||||
t, KQ_pos, n_rot, rope_mode, n_ctx,
|
||||
rope_freq_base, rope_freq_scale);
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue