parallel : process system prompt once + configurable parameters + llama API
This commit is contained in:
parent 82e20e9ba0
commit 4b5f3cd6bf

9 changed files with 187 additions and 93 deletions

@@ -977,7 +977,7 @@ int main(int argc, char ** argv) {
         test t(inst, lmodel, ctx);
 
-        llama_kv_cache_rm_tokens(ctx, -1, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);
 
         // warmup run
         if (t.n_prompt > 0) {

@@ -988,7 +988,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_rm_tokens(ctx, -1, -1);
+            llama_kv_cache_tokens_rm(ctx, -1, -1);
 
             uint64_t t_start = get_time_ns();
             if (t.n_prompt > 0) {

@@ -505,8 +505,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-            llama_kv_cache_rm_seq   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
             n_past -= n_discard;

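For orientation, here is a minimal sketch of the context-swapping step that the two renamed calls implement: once the context is full, the tokens after the first `n_keep` are halved, the discarded span is removed from the sequence, and the remaining cells are shifted back so generation can continue. The `n_left`/`n_discard` arithmetic below is reconstructed from the log message above and should be read as an assumption, not as part of this diff.

```cpp
// Sketch (assumption): context shift built around the two renamed KV-cache calls.
// ctx, n_past, n_ctx and params.n_keep are as in the hunk above.
const int n_left    = n_past - params.n_keep - 1;  // tokens past the keep region (assumed formula)
const int n_discard = n_left / 2;                  // drop half of them (assumed)

// remove the discarded span from sequence 0 ...
llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1, params.n_keep + n_discard + 1);
// ... and slide the remaining cells back by n_discard positions
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;
```
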
@@ -35,7 +35,7 @@ User: Hello, what is the temperature outside?
 Assistant: It is 72 degrees Fahrenheit.
 User: What is the definition of a prime number?
 Assistant: A prime number is a number that is divisible only by itself and 1.
-User: )";
+User:)";
 
 static std::vector<std::string> k_prompts = {
     "What is the meaning of life?",

@@ -70,7 +70,7 @@ struct client {
     std::string prompt;
     std::string response;
 
-    std::vector<llama_token> last_tokens;
+    std::vector<llama_token> tokens_prev;
 };
 
 int main(int argc, char ** argv) {

@@ -80,13 +80,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int n_clients = 8;
-
-    // insert new requests as soon as the previous one is done
-    const bool hot_plug = true;
+    // number of simultaneous "clients" to simulate
+    const int32_t n_clients = params.n_parallel;
 
     // requests to simulate
-    const int32_t n_seq = 128;
+    const int32_t n_seq = params.n_sequences;
+
+    // insert new requests as soon as the previous one is done
+    const bool hot_plug = params.hot_plug;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));

@@ -114,13 +114,17 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.last_tokens.resize(n_ctx);
-        std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+        client.tokens_prev.resize(n_ctx);
+        std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
     }
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
+    std::vector<llama_token> tokens_system;
+    tokens_system = ::llama_tokenize(ctx, k_system, true);
+    const uint32_t n_tokens_system = tokens_system.size();
+
     llama_seq_id g_seq_id = 0;
 
     std::vector<llama_token> batch_token;

@@ -134,6 +139,44 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
+    LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system);
+    LOG_TEE("\n");
+
+    {
+        LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
+
+        batch_pos.clear();
+        batch_seq_id.clear();
+
+        for (size_t i = 0; i < n_tokens_system; ++i) {
+            batch_pos.push_back(i);
+            batch_seq_id.push_back(0);
+        }
+
+        llama_batch batch = {
+            n_tokens_system,
+            tokens_system.data(),
+            nullptr,
+            batch_pos.data(),
+            batch_seq_id.data(),
+            nullptr,
+            0, 0, 0, // unused
+        };
+
+        if (llama_decode(ctx, batch, params.n_threads) != 0) {
+            LOG_TEE("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        // assign the system KV cache to all parallel sequences
+        for (int32_t i = 1; i < n_clients; ++i) {
+            llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+        }
+
+        LOG_TEE("\n");
+    }
+
     while (true) {
         uint32_t n_tokens = 0;

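The block above is the heart of the change: the shared system prompt is decoded once into sequence 0, and its KV-cache cells are then attached to every client sequence, so no client re-evaluates it. A condensed sketch of that pattern; `llama_batch_get_one()` (which appears in the speculative hunks below) is used here purely for brevity, while the commit itself builds the batch field by field as shown above:

```cpp
// Sketch: evaluate a shared prefix once, then share its KV cells across sequences.
std::vector<llama_token> tokens_system = ::llama_tokenize(ctx, k_system, true);
const int n_tokens_system = (int) tokens_system.size();

// decode the system prompt into sequence 0, starting at position 0
if (llama_decode(ctx, llama_batch_get_one(tokens_system.data(), n_tokens_system, 0, 0),
                 params.n_threads) != 0) {
    return 1;
}

// make the same cells belong to sequences 1 .. n_clients-1 as well;
// nothing is recomputed, the existing cells are simply reused
for (int32_t i = 1; i < n_clients; ++i) {
    llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
```
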
@@ -148,7 +191,7 @@ int main(int argc, char ** argv) {
             }
 
             batch_token.push_back(client.sampled);
-            batch_pos.push_back(client.n_decoded + client.n_prompt);
+            batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
             batch_seq_id.push_back(client.seq_id);
             batch_logits.push_back(true);
             batch_clients.push_back(&client);

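Worked example with illustrative numbers (not taken from the diff): if the system prompt tokenizes to 32 tokens and a client's prompt to 10, the system prompt occupies positions 0..31 of every sequence, the client's own prompt occupies positions 32..41, and each newly sampled token is decoded at `n_tokens_system + client.n_prompt + client.n_decoded`, continuing from position 42 upward. In the old code the system prompt was concatenated into every client prompt, so the same offset was hidden inside `client.n_prompt`.
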
@@ -158,34 +201,36 @@ int main(int argc, char ** argv) {
 
         if (batch_token.empty()) {
             // all sequences have ended - clear the entire KV cache
-            llama_kv_cache_rm_tokens(ctx, -1, -1);
+            for (int i = 0; i < n_clients; ++i) {
+                llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+            }
         }
 
         if (hot_plug || batch_token.empty()) {
             for (auto & client : clients) {
                 if (client.seq_id == -1 && g_seq_id < n_seq) {
-                    client.seq_id = g_seq_id;
+                    client.seq_id = client.id;
                     client.t_start_prompt = ggml_time_us();
                     client.t_start_gen    = 0;
 
                     client.input = k_prompts[rand() % k_prompts.size()];
-                    client.prompt = k_system + client.input + "\nAssistant:";
+                    client.prompt = client.input + "\nAssistant:";
                     client.response = "";
-                    std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
+                    std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
 
-                    std::vector<llama_token> prompt_tokens;
-                    prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
+                    std::vector<llama_token> tokens_prompt;
+                    tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
 
-                    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
-                        batch_token.push_back(prompt_tokens[i]);
-                        batch_pos.push_back(i);
+                    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
+                        batch_token.push_back(tokens_prompt[i]);
+                        batch_pos.push_back(i + n_tokens_system);
                         batch_seq_id.push_back(client.seq_id);
                         batch_clients.push_back(&client);
                         batch_logits.push_back(false);
                     }
                     batch_logits.back() = true;
 
-                    client.n_prompt  = prompt_tokens.size();
+                    client.n_prompt  = tokens_prompt.size();
                     client.n_decoded = 0;
                     client.i_batch   = batch_token.size() - 1;

@@ -217,9 +262,10 @@ int main(int argc, char ** argv) {
             0, 0, 0, // unused
         };
 
-        if (llama_decode(ctx, batch, params.n_threads)) {
-            if (n_batch == 1) {
-                LOG_TEE("%s : failed to decode batch\n", __func__);
+        const int ret = llama_decode(ctx, batch, params.n_threads);
+        if (ret != 0) {
+            if (n_batch == 1 || ret < 0) {
+                LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
                 return 1;
             }
 

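The reworked error handling separates a hard failure (`ret < 0`) from the recoverable case where the batch could not be processed as a whole: only when even a single-token batch fails, or the error is negative, does the example give up. A hedged sketch of a chunked-decode loop built around that rule; the hypothetical helper `decode_in_chunks`, the halve-and-retry policy, and the vector element types are assumptions rather than code from this commit:

```cpp
#include <algorithm>
#include <vector>
#include "llama.h"

// Sketch (assumption): decode queued tokens in chunks of n_batch, shrinking the
// chunk size whenever llama_decode() reports a recoverable failure.
static bool decode_in_chunks(llama_context * ctx,
                             std::vector<llama_token>  & tokens,
                             std::vector<llama_pos>    & pos,     // position type name assumed
                             std::vector<llama_seq_id> & seq_id,
                             std::vector<int8_t>       & logits,  // per-token logits flags, type assumed
                             int32_t n_batch, int n_threads) {
    for (int32_t i = 0; i < (int32_t) tokens.size(); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, (int32_t) tokens.size() - i);

        llama_batch batch = {
            n_tokens,
            tokens.data() + i,
            nullptr,            // no embeddings
            pos.data()    + i,
            seq_id.data() + i,
            logits.data() + i,
            0, 0, 0,            // unused
        };

        const int ret = llama_decode(ctx, batch, n_threads);
        if (ret != 0) {
            if (n_batch == 1 || ret < 0) {
                return false;   // a single token still fails, or a hard error
            }
            n_batch /= 2;       // assumed recovery: halve the chunk size ...
            i -= n_batch;       // ... and redo the same chunk on the next iteration
        }
    }
    return true;
}
```
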
@@ -242,7 +288,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
+            const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients

@@ -251,8 +297,8 @@ int main(int argc, char ** argv) {
             }
 
             // remember which tokens were sampled - used for repetition penalties during sampling
-            client.last_tokens.erase(client.last_tokens.begin());
-            client.last_tokens.push_back(id);
+            client.tokens_prev.erase(client.tokens_prev.begin());
+            client.tokens_prev.push_back(id);
 
             const std::string token_str = llama_token_to_piece(ctx, id);
             client.response += token_str;

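`tokens_prev` (the renamed `last_tokens`) is a fixed-length window of recently sampled tokens: it is pre-sized to `n_ctx` and zero-filled when a client is (re)initialized, and after every sample the oldest entry is dropped and the new token appended, so `llama_sample_token()` can apply repetition penalties over it. A minimal sketch of that update, mirroring the two lines above:

```cpp
// Sketch: sliding window of recent tokens used for repetition penalties.
std::vector<llama_token> tokens_prev(n_ctx, 0);   // pre-filled with zeros, as in the earlier hunk

// ... after sampling token `id` for this client ...
tokens_prev.erase(tokens_prev.begin());           // drop the oldest entry
tokens_prev.push_back(id);                        // append the newly sampled token
```
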
@@ -271,7 +317,8 @@ int main(int argc, char ** argv) {
                     client.response = client.response.substr(0, pos);
                 }
 
-                llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
+                // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
+                llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx);
 
                 const auto t_main_end = ggml_time_us();

@@ -207,7 +207,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;

@@ -335,7 +335,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;

@@ -568,7 +568,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // clear the KV cache
-        llama_kv_cache_keep_seq(ctx, -1);
+        llama_kv_cache_tokens_rm(ctx, -1, -1);
 
         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
         if (logits.empty()) {

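All three sites above make the same substitution: each text chunk (or HellaSwag query) is scored independently, so the whole KV cache is wiped before it is evaluated, now via `llama_kv_cache_tokens_rm` with open-ended bounds instead of the removed `llama_kv_cache_keep_seq`. As a usage sketch (assuming `ctx` is an initialized `llama_context`):

```cpp
// Sketch: start every independent evaluation from an empty KV cache;
// -1, -1 selects the full cell range, i.e. everything currently cached.
llama_kv_cache_tokens_rm(ctx, -1, -1);
```
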
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
             LOG("out of drafted tokens\n");
         }
 
-        llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
+        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
         llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
         ++n_past_dft;

@@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the drafted token on the draft model
-        llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
+        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
         llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
         ++n_past_cur;

@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
-        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;

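The three speculative-decoding hunks are the same rename, and they all follow one pattern: before decoding at position `n_past`, any cells of the (single) sequence at that position or later are removed, rolling the cache back to the accepted prefix, and the new token is then decoded there. A minimal sketch of that step, using only calls that appear in this diff (`ctx` stands for either the draft or the target context, and `id` for the token to decode):

```cpp
// Sketch: roll the cache back to n_past accepted tokens, then decode one token there.
llama_kv_cache_seq_rm(ctx, 0, n_past, n_ctx);  // drop cells of seq 0 at positions [n_past, n_ctx)

llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0), params.n_threads);
++n_past;
```
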