llama : custom attention mask + parallel decoding + no context swaps (#3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable parameters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close #3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <slarengh@gmail.com>
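
The KV cache bullets above (llama_kv_cache_shift_seq, "server : avoid context swaps by shifting the KV cache") replace the old context swap with in-place edits of the cache. The sketch below illustrates the idea for a single stream; it assumes the sequence calls exposed by the merged header, llama_kv_cache_seq_rm and llama_kv_cache_seq_shift (the latter being the renamed llama_kv_cache_shift_seq), so verify the exact names and signatures against llama.h:

#include "llama.h"

// Sketch: free room in sequence 0 for more tokens without re-evaluating the kept context.
// Assumes the sequence-aware KV cache API introduced by this PR; verify names in llama.h.
static void kv_shift(llama_context * ctx, int & n_past, int n_keep, int n_discard) {
    // drop the cache cells of tokens [n_keep, n_keep + n_discard) of sequence 0
    llama_kv_cache_seq_rm(ctx, 0, n_keep, n_keep + n_discard);

    // slide the remaining tokens [n_keep + n_discard, n_past) back by n_discard;
    // the K cells are re-roped to their new positions, so nothing is decoded again
    llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_past, -n_discard);

    n_past -= n_discard;
}

Compared with the old swap, only the RoPE rotation of the shifted K cells has to be updated; the kept part of the context is never evaluated again.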
commit ec893798b7
parent 45855b3f1c
Author: Georgi Gerganov
Date:   2023-09-28 19:04:36 +03:00 (committed by GitHub)
35 changed files with 2700 additions and 673 deletions


--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -26,12 +26,18 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }
 
+    // total length of the sequence including the prompt
+    const int n_len = 32;
+
     // init LLM
 
     llama_backend_init(params.numa);
 
     llama_context_params ctx_params = llama_context_default_params();
 
+    ctx_params.seed = 1234;
+    ctx_params.n_ctx = 2048;
+
     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
 
     if (model == NULL) {
@@ -41,20 +47,31 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
     }
 
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);
 
-    const int max_context_size = llama_n_ctx(ctx);
-    const int max_tokens_list_size = max_context_size - 4;
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
 
-    if ((int) tokens_list.size() > max_tokens_list_size) {
-        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
 
-    fprintf(stderr, "\n\n");
+    // print the prompt token-by-token
+
+    fprintf(stderr, "\n");
 
     for (auto id : tokens_list) {
         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
@@ -62,63 +79,104 @@ int main(int argc, char ** argv) {
     fflush(stderr);
 
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding
+
+    llama_batch batch = llama_batch_init(512, 0);
+
+    // evaluate the initial prompt
+    batch.n_tokens = tokens_list.size();
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i] = tokens_list[i];
+        batch.pos[i] = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
     // main loop
 
-    // The LLM keeps a contextual cache memory of previous token evaluation.
-    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
-    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
-    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+    int n_cur = batch.n_tokens;
+    int n_decode = 0;
 
-    const int n_gen = std::min(32, max_context_size);
+    const auto t_main_start = ggml_time_us();
 
-    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
-        // evaluate the transformer
+    while (n_cur <= n_len) {
+        // sample the next token
+        {
+            auto n_vocab = llama_n_vocab(ctx);
+            auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
 
-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of stream?
+            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+                LOG_TEE("\n");
+
+                break;
+            }
+
+            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            fflush(stdout);
+
+            // prepare the next batch
+            batch.n_tokens = 0;
+
+            // push this new token for next evaluation
+            batch.token [batch.n_tokens] = new_token_id;
+            batch.pos [batch.n_tokens] = n_cur;
+            batch.seq_id[batch.n_tokens] = 0;
+            batch.logits[batch.n_tokens] = true;
+
+            batch.n_tokens += 1;
+
+            n_decode += 1;
+        }
+
+        n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
-
-        tokens_list.clear();
-
-        // sample the next token
-
-        llama_token new_token_id = 0;
-
-        auto logits = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(ctx);
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-        }
-
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-        new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
-
-        // is it an end of stream ?
-        if (new_token_id == llama_token_eos(ctx)) {
-            fprintf(stderr, " [end of text]\n");
-            break;
-        }
-
-        // print the new token :
-        printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
-        fflush(stdout);
-
-        // push this new token for next evaluation
-        tokens_list.push_back(new_token_id);
     }
 
+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+
     llama_free(ctx);
     llama_free_model(model);
 
     llama_backend_free();
 
-    fprintf(stderr, "\n\n");
-
     return 0;
 }
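
The updated simple example above still decodes a single stream: every token is submitted with seq_id 0. Parallel decoding, as used by the new batched and parallel examples, reuses the same llama_batch fields with one sequence id per client so that a single llama_decode() call evaluates all clients at once. The sketch below is a rough illustration under two assumptions: the shared prompt has already been decoded into sequence 0, and llama_kv_cache_seq_cp is the KV copy helper added by this PR (verify both against llama.h):

#include "llama.h"

#include <vector>

// copy the evaluated prompt cells from sequence 0 to the other sequences,
// so the shared prompt is decoded only once (llama_kv_cache_seq_cp assumed from this PR)
static void share_prompt_cache(llama_context * ctx, int32_t n_parallel, int32_t n_prompt_tokens) {
    for (int32_t i = 1; i < n_parallel; ++i) {
        llama_kv_cache_seq_cp(ctx, 0, i, 0, n_prompt_tokens);
    }
}

// submit one freshly sampled token per client and evaluate them as a single batch;
// the batch must have been created with enough capacity, e.g. llama_batch_init(512, 0)
static int decode_one_token_per_client(llama_context * ctx, llama_batch & batch,
                                       const std::vector<llama_token> & next_token, // one token per client
                                       llama_pos n_cur, int n_threads) {
    batch.n_tokens = 0;

    for (int32_t i = 0; i < (int32_t) next_token.size(); ++i) {
        batch.token [batch.n_tokens] = next_token[i];
        batch.pos   [batch.n_tokens] = n_cur; // all clients are at the same position in this sketch
        batch.seq_id[batch.n_tokens] = i;     // each client owns its own sequence in the shared KV cache
        batch.logits[batch.n_tokens] = true;  // logits are needed for every client
        batch.n_tokens += 1;
    }

    return llama_decode(ctx, batch, n_threads);
}

Because all sequences live in one KV cache, clients can be mixed freely within a batch; the parallel example builds on this pattern for its continuous batching of multiple clients.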