llama : unified KV cache + batch inference API
commit d29e76937c (parent fad56936d4)
10 changed files with 315 additions and 236 deletions
ggml.c | 6 ++---- (only the ggml.c hunks are reproduced below)
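The commit title pairs two changes: a unified KV cache, where each cache cell carries its own position and sequence id instead of relying on a single global n_past, and a batch inference API, where that per-token metadata is submitted explicitly. A minimal sketch of the shape of these ideas, using hypothetical names that need not match the actual llama.cpp types:

#include <stdint.h>
#include <stdio.h>

// Hypothetical sketch of a unified KV cache cell (illustrative names,
// not the actual llama.cpp structs). Every stored token keeps its own
// position and the sequence it belongs to, so several sequences can
// share one cache buffer without a single global n_past scalar.
struct kv_cell {
    int32_t pos;    // position of the token within its sequence
    int32_t seq_id; // sequence that owns this cell
};

// Batch API sketch: position and sequence id travel with each token
// instead of being implied by a (tokens, n_past) pair.
struct token_batch {
    int32_t         n_tokens;
    const int32_t * token;   // token ids
    const int32_t * pos;     // per-token positions
    const int32_t * seq_id;  // per-token sequence ids
};

int main(void) {
    // Two independent sequences submitted in one batch.
    const int32_t token[4]  = { 15043, 3186, 22172, 3186 };
    const int32_t pos[4]    = { 0, 1, 0, 1 };
    const int32_t seq_id[4] = { 0, 0, 1, 1 };
    struct token_batch batch = { 4, token, pos, seq_id };

    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        printf("token %d -> seq %d, pos %d\n",
               batch.token[i], batch.seq_id[i], batch.pos[i]);
    }
    return 0;
}

Tagging cells per token is what makes a scalar n_past meaningless, which is why the ggml.c hunks below comment out its use.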
--- a/ggml.c
+++ b/ggml.c
@@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }

-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];

     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);

     // add alibi to src0 (KQ_scaled)
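In the hunks above, n_past was only read for assertions; the ALiBi bias itself is computed from the key column index i and a per-head slope m_k, so nothing in the kernel needs a global past length once positions are tracked per token. A condensed sketch of the slope computation, assumed from the surrounding ggml code of this period rather than quoted verbatim:

#include <math.h>
#include <stdio.h>

// Condensed sketch of the ALiBi per-head slope (assumption based on the
// surrounding ggml code, not the verbatim kernel). The bias added to a
// KQ_scaled entry depends only on the head index k and the key column i,
// which is why the kernel above can drop n_past entirely.
static float alibi_slope(int k, int n_head, float max_bias) {
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
    return k < n_heads_log2_floor
        ? powf(m0, (float)(k + 1))
        : powf(m1, (float)(2 * (k - n_heads_log2_floor) + 1));
}

int main(void) {
    // Example: 8 heads, max_bias = 8.0f. The bias added at key column i
    // for head k is then i * alibi_slope(k, 8, 8.0f).
    for (int k = 0; k < 8; ++k) {
        printf("head %d: slope %.6f\n", k, alibi_slope(k, 8, 8.0f));
    }
    return 0;
}

With these parameters the slopes form the familiar geometric sequence 1/2, 1/4, ..., 1/256.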