llama : unified KV cache + batch inference API

This commit is contained in:
Georgi Gerganov 2023-09-18 10:08:22 +03:00
parent fad56936d4
commit d29e76937c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 315 additions and 236 deletions

6
ggml.c
View file

@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
return;
}
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
assert(n_past >= 0);
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k
@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
//const int nb3 = src0->nb[3];
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)