llama : custom attention mask + parallel decoding + no context swaps (#3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence errors KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable paramters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close #3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2023-09-28 19:04:36 +03:00 committed by GitHub
parent 45855b3f1c
commit ec893798b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 2700 additions and 673 deletions

View file

@ -24,12 +24,59 @@ typedef struct {
int8_t qs[QK8_0]; // quants
} block_q8_0;
// general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient
kernel void kernel_add(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig];
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & nb00,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant int64_t & nb0,
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig.z;
const int64_t i02 = tgpig.y;
const int64_t i01 = tgpig.x;
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
src0_ptr += ntg.x*nb00;
src1_ptr += ntg.x*nb10;
dst_ptr += ntg.x*nb0;
}
}
// assumption: src1 is a row
@ -38,7 +85,7 @@ kernel void kernel_add_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
constant int64_t & nb [[buffer(27)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
@ -806,30 +853,61 @@ kernel void kernel_alibi_f32(
}
}
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
template<typename T>
kernel void kernel_rope(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
device const void * src0,
device const int32_t * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
constant float & freq_base,
constant float & freq_scale,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
@ -839,7 +917,9 @@ kernel void kernel_rope(
const bool is_neox = mode & 2;
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
device const int32_t * pos = src1;
const int64_t p = pos[i2];
const float theta_0 = freq_scale * (float)p;
const float inv_ndims = -1.f/n_dims;
@ -851,11 +931,11 @@ kernel void kernel_rope(
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
const T x0 = src[0];
const T x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
@ -870,8 +950,8 @@ kernel void kernel_rope(
const int64_t i0 = ib*n_dims + ic/2;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
@ -883,6 +963,9 @@ kernel void kernel_rope(
}
}
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
float yl[32];
const uint16_t kmask1 = 0x3030;
const uint16_t kmask2 = 0x0f0f;
//const uint16_t kmask1 = 0x3030;
//const uint16_t kmask2 = 0x0f0f;
const int tid = tiisg/4;
const int ix = tiisg%4;