parallel : process system prompt once + configurable paramters + llama API

This commit is contained in:
Georgi Gerganov 2023-09-19 17:00:42 +03:00
parent 82e20e9ba0
commit 4b5f3cd6bf
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 187 additions and 93 deletions

View file

@ -35,7 +35,7 @@ User: Hello, what is the temperature outside?
Assistant: It is 72 degrees Fahrenheit.
User: What is the definition of a prime number?
Assistant: A prime number is a number that is divisible only by itself and 1.
User: )";
User:)";
static std::vector<std::string> k_prompts = {
"What is the meaning of life?",
@ -70,7 +70,7 @@ struct client {
std::string prompt;
std::string response;
std::vector<llama_token> last_tokens;
std::vector<llama_token> tokens_prev;
};
int main(int argc, char ** argv) {
@ -80,13 +80,14 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_clients = 8;
// insert new requests as soon as the previous one is done
const bool hot_plug = true;
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;
// requests to simulate
const int32_t n_seq = 128;
const int32_t n_seq = params.n_sequences;
// insert new requests as soon as the previous one is done
const bool hot_plug = params.hot_plug;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log"));
@ -114,13 +115,17 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.last_tokens.resize(n_ctx);
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
client.tokens_prev.resize(n_ctx);
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const uint32_t n_tokens_system = tokens_system.size();
llama_seq_id g_seq_id = 0;
std::vector<llama_token> batch_token;
@ -134,6 +139,44 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us();
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system);
LOG_TEE("\n");
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch_pos.clear();
batch_seq_id.clear();
for (size_t i = 0; i < n_tokens_system; ++i) {
batch_pos.push_back(i);
batch_seq_id.push_back(0);
}
llama_batch batch = {
n_tokens_system,
tokens_system.data(),
nullptr,
batch_pos.data(),
batch_seq_id.data(),
nullptr,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
// assign the system KV cachce to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
LOG_TEE("\n");
}
while (true) {
uint32_t n_tokens = 0;
@ -148,7 +191,7 @@ int main(int argc, char ** argv) {
}
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded + client.n_prompt);
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_logits.push_back(true);
batch_clients.push_back(&client);
@ -158,34 +201,36 @@ int main(int argc, char ** argv) {
if (batch_token.empty()) {
// all sequences have ended - clear the entire KV cache
llama_kv_cache_rm_tokens(ctx, -1, -1);
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
}
}
if (hot_plug || batch_token.empty()) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id;
client.seq_id = client.id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.prompt = client.input + "\nAssistant:";
client.response = "";
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
std::vector<llama_token> prompt_tokens;
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
batch_token.push_back(prompt_tokens[i]);
batch_pos.push_back(i);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch_token.push_back(tokens_prompt[i]);
batch_pos.push_back(i + n_tokens_system);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
batch_logits.push_back(false);
}
batch_logits.back() = true;
client.n_prompt = prompt_tokens.size();
client.n_prompt = tokens_prompt.size();
client.n_decoded = 0;
client.i_batch = batch_token.size() - 1;
@ -217,9 +262,10 @@ int main(int argc, char ** argv) {
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads)) {
if (n_batch == 1) {
LOG_TEE("%s : failed to decode batch\n", __func__);
const int ret = llama_decode(ctx, batch, params.n_threads);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return 1;
}
@ -242,7 +288,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@ -251,8 +297,8 @@ int main(int argc, char ** argv) {
}
// remember which tokens were sampled - used for repetition penalties during sampling
client.last_tokens.erase(client.last_tokens.begin());
client.last_tokens.push_back(id);
client.tokens_prev.erase(client.tokens_prev.begin());
client.tokens_prev.push_back(id);
const std::string token_str = llama_token_to_piece(ctx, id);
client.response += token_str;
@ -271,7 +317,8 @@ int main(int argc, char ** argv) {
client.response = client.response.substr(0, pos);
}
llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx);
const auto t_main_end = ggml_time_us();