parallel : process system prompt once + configurable paramters + llama API

This commit is contained in:
Georgi Gerganov 2023-09-19 17:00:42 +03:00
parent 82e20e9ba0
commit 4b5f3cd6bf
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 187 additions and 93 deletions

View file

@ -977,7 +977,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
llama_kv_cache_rm_tokens(ctx, -1, -1);
llama_kv_cache_tokens_rm(ctx, -1, -1);
// warmup run
if (t.n_prompt > 0) {
@ -988,7 +988,7 @@ int main(int argc, char ** argv) {
}
for (int i = 0; i < params.reps; i++) {
llama_kv_cache_rm_tokens(ctx, -1, -1);
llama_kv_cache_tokens_rm(ctx, -1, -1);
uint64_t t_start = get_time_ns();
if (t.n_prompt > 0) {

View file

@ -505,8 +505,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;

View file

@ -35,7 +35,7 @@ User: Hello, what is the temperature outside?
Assistant: It is 72 degrees Fahrenheit.
User: What is the definition of a prime number?
Assistant: A prime number is a number that is divisible only by itself and 1.
User: )";
User:)";
static std::vector<std::string> k_prompts = {
"What is the meaning of life?",
@ -70,7 +70,7 @@ struct client {
std::string prompt;
std::string response;
std::vector<llama_token> last_tokens;
std::vector<llama_token> tokens_prev;
};
int main(int argc, char ** argv) {
@ -80,13 +80,14 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_clients = 8;
// insert new requests as soon as the previous one is done
const bool hot_plug = true;
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;
// requests to simulate
const int32_t n_seq = 128;
const int32_t n_seq = params.n_sequences;
// insert new requests as soon as the previous one is done
const bool hot_plug = params.hot_plug;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log"));
@ -114,13 +115,17 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.last_tokens.resize(n_ctx);
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
client.tokens_prev.resize(n_ctx);
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const uint32_t n_tokens_system = tokens_system.size();
llama_seq_id g_seq_id = 0;
std::vector<llama_token> batch_token;
@ -134,6 +139,44 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system);
LOG_TEE("\n");
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch_pos.clear();
batch_seq_id.clear();
for (size_t i = 0; i < n_tokens_system; ++i) {
batch_pos.push_back(i);
batch_seq_id.push_back(0);
}
llama_batch batch = {
n_tokens_system,
tokens_system.data(),
nullptr,
batch_pos.data(),
batch_seq_id.data(),
nullptr,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
// assign the system KV cachce to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
LOG_TEE("\n");
}
while (true) {
uint32_t n_tokens = 0;
@ -148,7 +191,7 @@ int main(int argc, char ** argv) {
}
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded + client.n_prompt);
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_logits.push_back(true);
batch_clients.push_back(&client);
@ -158,34 +201,36 @@ int main(int argc, char ** argv) {
if (batch_token.empty()) {
// all sequences have ended - clear the entire KV cache
llama_kv_cache_rm_tokens(ctx, -1, -1);
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
}
}
if (hot_plug || batch_token.empty()) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id;
client.seq_id = client.id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.prompt = client.input + "\nAssistant:";
client.response = "";
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
std::vector<llama_token> prompt_tokens;
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
batch_token.push_back(prompt_tokens[i]);
batch_pos.push_back(i);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch_token.push_back(tokens_prompt[i]);
batch_pos.push_back(i + n_tokens_system);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
batch_logits.push_back(false);
}
batch_logits.back() = true;
client.n_prompt = prompt_tokens.size();
client.n_prompt = tokens_prompt.size();
client.n_decoded = 0;
client.i_batch = batch_token.size() - 1;
@ -217,9 +262,10 @@ int main(int argc, char ** argv) {
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads)) {
if (n_batch == 1) {
LOG_TEE("%s : failed to decode batch\n", __func__);
const int ret = llama_decode(ctx, batch, params.n_threads);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return 1;
}
@ -242,7 +288,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@ -251,8 +297,8 @@ int main(int argc, char ** argv) {
}
// remember which tokens were sampled - used for repetition penalties during sampling
client.last_tokens.erase(client.last_tokens.begin());
client.last_tokens.push_back(id);
client.tokens_prev.erase(client.tokens_prev.begin());
client.tokens_prev.push_back(id);
const std::string token_str = llama_token_to_piece(ctx, id);
client.response += token_str;
@ -271,7 +317,8 @@ int main(int argc, char ** argv) {
client.response = client.response.substr(0, pos);
}
llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx);
const auto t_main_end = ggml_time_us();

View file

@ -207,7 +207,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
llama_kv_cache_tokens_rm(ctx, -1, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
llama_kv_cache_tokens_rm(ctx, -1, -1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -568,7 +568,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}
// clear the KV cache
llama_kv_cache_keep_seq(ctx, -1);
llama_kv_cache_tokens_rm(ctx, -1, -1);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {

View file

@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n");
}
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
++n_past_dft;
@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
}
// evaluate the drafted token on the draft model
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
++n_past_cur;
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
}
// evaluate the target model on the drafted tokens
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
++n_past_tgt;