llama : custom attention mask + parallel decoding + no context swaps (#3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask; see the first sketch after this list)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch (see the second sketch after this list)

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable parameters + llama API

* parallel : remove questions with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close #3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache
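
For reference, a minimal sketch of the custom attention mask change, assuming a KQ_scaled tensor and a caller-built f32 KQ_mask (the [n_kv, n_tokens, 1] shape is an assumption) holding 0.0f for visible entries and -INFINITY for masked ones:

    // add the user-built mask instead of calling ggml_diag_mask_inf()
    struct ggml_tensor * KQ_masked   = ggml_add(ctx, KQ_scaled, KQ_mask);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx, KQ_masked);

This lets each batch carry an arbitrary per-token visibility pattern rather than only the diagonal n_past mask.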
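
And a minimal sketch of the new decoding flow (names as introduced by this PR; model/context setup and error handling omitted, token values are placeholders):

    // wrap a token array into a single-sequence batch starting at position 0
    llama_token tokens[4] = { 1, 2, 3, 4 }; // hypothetical prompt tokens
    struct llama_batch batch = llama_batch_get_one(tokens, 4, 0, 0);

    if (llama_decode(ctx, batch) != 0) {
        // decoding failed, e.g. no free slot in the KV cache for this batch
    }

    // logits are now fetched per token index instead of only for the last one
    float * logits = llama_get_logits_ith(ctx, 3);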

---------

Co-authored-by: slaren <slarengh@gmail.com>
Author: Georgi Gerganov, 2023-09-28 19:04:36 +03:00 (committed by GitHub)
parent 45855b3f1c
commit ec893798b7
35 changed files with 2700 additions and 673 deletions

ggml.c (170 changes)

@@ -6406,6 +6406,54 @@ struct ggml_tensor * ggml_cont_inplace(
return ggml_cont_impl(ctx, a, true);
}
// make contiguous, with new shape
GGML_API struct ggml_tensor * ggml_cont_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0) {
return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
}
GGML_API struct ggml_tensor * ggml_cont_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1) {
return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
}
GGML_API struct ggml_tensor * ggml_cont_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2) {
return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
}
struct ggml_tensor * ggml_cont_4d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
int64_t ne1,
int64_t ne2,
int64_t ne3) {
GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
bool is_node = false;
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
ggml_format_name(result, "%s (cont)", a->name);
result->op = GGML_OP_CONT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
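
For illustration, a hedged usage sketch of the new helpers (the tensors here are hypothetical):

    // make a transposed, non-contiguous view contiguous again with an explicit
    // 2-D shape; ggml_cont_2d() asserts that the element counts match
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * v = ggml_transpose(ctx, t);     // 3x4 view, not contiguous
    struct ggml_tensor * c = ggml_cont_2d(ctx, v, 3, 4); // contiguous 3x4 copy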
// ggml_reshape
struct ggml_tensor * ggml_reshape(
@@ -6968,7 +7016,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
static struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
@@ -6977,7 +7025,10 @@ static struct ggml_tensor * ggml_rope_impl(
float xpos_base,
bool xpos_down,
bool inplace) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
bool is_node = false;
if (a->grad) {
@@ -6986,7 +7037,7 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
@@ -6996,6 +7047,7 @@ static struct ggml_tensor * ggml_rope_impl(
result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
@@ -7003,55 +7055,55 @@ static struct ggml_tensor * ggml_rope_impl(
struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
}
struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
}
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
}
// ggml_rope_back
@@ -7059,7 +7111,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx,
@@ -7067,7 +7119,10 @@ struct ggml_tensor * ggml_rope_back(
float freq_scale,
float xpos_base,
bool xpos_down) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
bool is_node = false;
@@ -7078,7 +7133,7 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[8] = { n_past, n_dims, mode, n_ctx };
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float));
@@ -7088,6 +7143,7 @@ struct ggml_tensor * ggml_rope_back(
result->op = GGML_OP_ROPE_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
return result;
}
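
To illustrate the new calling convention, a sketch of building the positions vector b (cur and n_past are hypothetical names; it assumes the position buffer is writable at graph-build time):

    // one I32 position per token; the assert above requires a->ne[2] entries
    const int n_tokens = cur->ne[2];
    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        ((int32_t *) pos->data)[i] = n_past + i; // absolute position of token i
    }
    cur = ggml_rope(ctx, cur, pos, n_dims, /*mode*/ 0, /*n_ctx*/ 0);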
@@ -8798,8 +8854,6 @@ static void ggml_compute_forward_add_f32(
#else
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
// }
// }
}
} else {
// src1 is not contiguous
@@ -12456,13 +12510,11 @@ static void ggml_compute_forward_alibi_f16(
return;
}
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
assert(n_past >= 0);
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12477,7 +12529,7 @@ static void ggml_compute_forward_alibi_f16(
//const int nb3 = src0->nb[3];
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
@@ -12623,8 +12675,8 @@ static void ggml_compute_forward_clamp(
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
@@ -12634,9 +12686,9 @@ static void ggml_compute_forward_rope_f32(
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -12645,8 +12697,6 @@
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@@ -12677,9 +12727,11 @@
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@@ -12716,7 +12768,7 @@
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale;
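
For reference, the angle computed above for position p and dimension pair i0 is theta = freq_scale * p * freq_base^(-i0/n_dims), and the xPos-only scale is zeta = ((i0 + 0.4*ne0) / (1.4*ne0))^(p / xpos_base), inverted when xpos_down is set; p now comes from the per-token positions vector instead of n_past + i2.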
@@ -12761,8 +12813,8 @@
static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
@@ -12770,15 +12822,13 @@
float freq_base;
float freq_scale;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@@ -12809,9 +12859,11 @@
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@@ -12890,15 +12942,16 @@
static void ggml_compute_forward_rope(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, src0, dst);
ggml_compute_forward_rope_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, src0, dst);
ggml_compute_forward_rope_f32(params, src0, src1, dst);
} break;
default:
{
@@ -12912,6 +12965,7 @@ static void ggml_compute_forward_rope(
static void ggml_compute_forward_rope_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -12929,7 +12983,7 @@ static void ggml_compute_forward_rope_back_f32(
float xpos_base;
bool xpos_down;
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
@@ -12938,8 +12992,6 @@ static void ggml_compute_forward_rope_back_f32(
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@@ -12966,9 +13018,11 @@ static void ggml_compute_forward_rope_back_f32(
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@@ -12980,7 +13034,7 @@ static void ggml_compute_forward_rope_back_f32(
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta *= theta_scale;
@@ -13023,6 +13077,7 @@ static void ggml_compute_forward_rope_back_f32(
static void ggml_compute_forward_rope_back_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -13033,12 +13088,10 @@ static void ggml_compute_forward_rope_back_f16(
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
assert(n_past >= 0);
GGML_TENSOR_UNARY_OP_LOCALS;
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
@@ -13065,9 +13118,11 @@ static void ggml_compute_forward_rope_back_f16(
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
@@ -13119,15 +13174,16 @@ static void ggml_compute_forward_rope_back_f16(
static void ggml_compute_forward_rope_back(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_back_f16(params, src0, dst);
ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_back_f32(params, src0, dst);
ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
} break;
default:
{
@@ -15864,11 +15920,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_ROPE:
{
ggml_compute_forward_rope(params, tensor->src[0], tensor);
ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ROPE_BACK:
{
ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ALIBI:
{
@@ -16506,7 +16562,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16523,7 +16579,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_rope_back(ctx,
tensor->grad,
n_past,
src1,
n_dims,
mode,
n_ctx,
@@ -16537,7 +16593,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_ROPE_BACK:
{
if (src0->grad) {
const int n_past = ((int32_t *) tensor->op_params)[0];
//const int n_past = ((int32_t *) tensor->op_params)[0];
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
@@ -16554,7 +16610,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad,
ggml_rope_impl(ctx,
tensor->grad,
n_past,
src1,
n_dims,
mode,
n_ctx,