llama : improve llama_batch API + simplify parallel example

Georgi Gerganov 2023-09-20 10:46:18 +03:00
parent a1327c71c6
commit addae65fd4
6 changed files with 111 additions and 70 deletions


@@ -127,11 +127,7 @@ int main(int argc, char ** argv) {
llama_seq_id g_seq_id = 0;
std::vector<llama_token> batch_token;
std::vector<llama_pos> batch_pos;
std::vector<llama_seq_id> batch_seq_id;
std::vector<int8_t> batch_logits;
std::vector<client *> batch_clients;
llama_batch batch = llama_batch_init(params.n_batch, 0);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
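
The hunk above is the core of the change: the five parallel std::vector buffers that used to be rebuilt every iteration are replaced by a single llama_batch allocated once with llama_batch_init. Below is a minimal sketch of the resulting lifecycle, assuming the llama_batch layout at this commit (flat token/pos/seq_id/logits arrays, an n_tokens counter maintained by the caller, and llama_decode still taking n_threads); the function name and the placeholder token id are illustrative.

```cpp
// Minimal sketch of the allocate/fill/free lifecycle (illustrative function name).
#include "llama.h"

static void batch_lifecycle_sketch(llama_context * ctx, int32_t n_batch, int n_threads) {
    // one allocation up front; the second argument (embd) is 0 for token-id input
    llama_batch batch = llama_batch_init(n_batch, 0);

    // the caller writes directly into the arrays and maintains n_tokens itself
    batch.n_tokens  = 1;
    batch.token [0] = 0;    // placeholder token id for the sketch
    batch.pos   [0] = 0;
    batch.seq_id[0] = 0;
    batch.logits[0] = true; // request logits for this position

    llama_decode(ctx, batch, n_threads);

    // a single free pairs with the single init
    llama_batch_free(batch);
}
```
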
@@ -146,24 +142,15 @@ int main(int argc, char ** argv) {
{
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch_pos.clear();
batch_seq_id.clear();
batch.n_tokens = n_tokens_system;
for (size_t i = 0; i < n_tokens_system; ++i) {
batch_pos.push_back(i);
batch_seq_id.push_back(0);
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
llama_batch batch = {
n_tokens_system,
tokens_system.data(),
nullptr,
batch_pos.data(),
batch_seq_id.data(),
nullptr,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
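
The system prompt is now written straight into the pre-allocated batch: each token gets position i, sequence 0, and logits disabled, since nothing is sampled from the prompt. A sketch of the same step as a standalone helper (the name and parameter list are illustrative):

```cpp
// Sketch of the system-prompt evaluation as a helper (illustrative name).
#include <vector>
#include "llama.h"

static bool eval_system_prompt(llama_context * ctx, llama_batch & batch,
                               const std::vector<llama_token> & tokens_system,
                               int n_threads) {
    batch.n_tokens = (uint32_t) tokens_system.size();

    for (uint32_t i = 0; i < batch.n_tokens; ++i) {
        batch.token [i] = tokens_system[i];
        batch.pos   [i] = i;     // positions start at 0
        batch.seq_id[i] = 0;     // the system prompt lives in sequence 0
        batch.logits[i] = false; // nothing is sampled from the prompt
    }

    // non-zero means the batch could not be processed (e.g. no free KV-cache slot)
    return llama_decode(ctx, batch, n_threads) == 0;
}
```
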
@@ -180,63 +167,72 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
uint32_t n_tokens = 0;
batch_token.clear();
batch_pos.clear();
batch_seq_id.clear();
batch_logits.clear();
batch.n_tokens = 0;
// decode any currently ongoing sequences
for (auto & client : clients) {
if (client.seq_id == -1) {
continue;
}
batch_token.push_back(client.sampled);
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
batch_seq_id.push_back(client.id);
batch_logits.push_back(true);
batch_clients.push_back(&client);
batch.token [batch.n_tokens] = client.sampled;
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = true;
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
client.i_batch = batch.n_tokens;
batch.n_tokens += 1;
}
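
Each active client now appends exactly one token, its last sampled one, at index batch.n_tokens, and records that index in client.i_batch so it can later locate its row of logits. A condensed sketch using a reduced stand-in for the example's client struct:

```cpp
// Sketch of the per-client append (reduced stand-in for the example's client struct).
#include "llama.h"

struct client_state {
    int         id        = 0;  // slot index, also used as the KV-cache sequence id
    llama_token sampled   = 0;  // last sampled token
    int         n_prompt  = 0;  // number of prompt tokens (after the system prompt)
    int         n_decoded = 0;  // number of tokens generated so far
    int         i_batch   = -1; // index of this client's logits in the current batch
};

static void append_ongoing(llama_batch & batch, client_state & client, int n_tokens_system) {
    const uint32_t j = batch.n_tokens;

    batch.token [j] = client.sampled;
    batch.pos   [j] = n_tokens_system + client.n_prompt + client.n_decoded;
    batch.seq_id[j] = client.id;   // the example keys KV-cache sequences by client slot
    batch.logits[j] = true;        // logits are needed to sample the next token

    client.n_decoded += 1;
    client.i_batch    = (int) j;   // remembered so the client can find its logits later

    batch.n_tokens += 1;
}
```
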
if (batch_token.empty()) {
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
}
}
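
When no client contributed a token, every request has finished, and the example clears each client sequence from position n_tokens_system onward. That keeps the shared system prompt resident in the KV cache while dropping everything the finished requests produced. A sketch of that reset, assuming the llama_kv_cache_seq_rm signature at this commit (p1 = -1 meaning "to the end of the sequence"):

```cpp
// Sketch of the KV-cache reset for finished requests.
#include "llama.h"

static void reset_client_sequences(llama_context * ctx, int n_clients, int n_tokens_system) {
    for (llama_seq_id i = 0; i < n_clients; ++i) {
        // remove positions [n_tokens_system, end) of sequence i;
        // the shared system prompt at positions [0, n_tokens_system) stays cached
        llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
    }
}
```
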
if (cont_batching || batch_token.empty()) {
// insert new sequences for decoding
if (cont_batching || batch.n_tokens == 0) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = client.input + "\nAssistant:";
client.response = "";
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch_token.push_back(tokens_prompt[i]);
batch_pos.push_back(i + n_tokens_system);
batch_seq_id.push_back(client.id);
batch_clients.push_back(&client);
batch_logits.push_back(false);
batch.token [batch.n_tokens] = tokens_prompt[i];
batch.pos [batch.n_tokens] = i + n_tokens_system;
batch.seq_id[batch.n_tokens] = client.id;
batch.logits[batch.n_tokens] = false;
batch.n_tokens += 1;
}
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
}
batch_logits.back() = true;
client.n_prompt = tokens_prompt.size();
client.n_decoded = 0;
client.i_batch = batch_token.size() - 1;
client.i_batch = batch.n_tokens - 1;
LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
g_seq_id += 1;
// insert new requests one-by-one
//if (cont_batching) {
// break;
//}
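
A new request is tokenized and appended with logits disabled for every prompt token except the last, which is the only position the example samples from; its index is stored as client.i_batch. A sketch of that step, using the ::llama_tokenize helper from common.h that the examples rely on (the helper function below and its return value are illustrative):

```cpp
// Sketch of appending a freshly tokenized prompt (illustrative helper).
#include <string>
#include <vector>
#include "llama.h"
#include "common.h"

static int append_prompt(llama_context * ctx, llama_batch & batch,
                         const std::string & prompt, llama_seq_id seq_id,
                         int n_tokens_system) {
    const std::vector<llama_token> tokens_prompt = ::llama_tokenize(ctx, prompt, true);

    for (size_t i = 0; i < tokens_prompt.size(); ++i) {
        batch.token [batch.n_tokens] = tokens_prompt[i];
        batch.pos   [batch.n_tokens] = i + n_tokens_system; // continue after the system prompt
        batch.seq_id[batch.n_tokens] = seq_id;
        batch.logits[batch.n_tokens] = false; // intermediate prompt tokens need no logits
        batch.n_tokens += 1;
    }

    // only the last prompt token requests logits - that is where sampling starts
    if (batch.n_tokens > 0) {
        batch.logits[batch.n_tokens - 1] = true;
    }

    // index of this request's logits row (what the example stores in client.i_batch)
    return (int) batch.n_tokens - 1;
}
```
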
@@ -244,34 +240,35 @@ int main(int argc, char ** argv) {
}
}
if (batch_token.empty()) {
if (batch.n_tokens == 0) {
break;
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch = {
llama_batch batch_view = {
n_tokens,
batch_token.data() + i,
batch.token + i,
nullptr,
batch_pos.data() + i,
batch_seq_id.data() + i,
batch_logits.data() + i,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch, params.n_threads);
const int ret = llama_decode(ctx, batch_view, params.n_threads);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return 1;
}
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
n_cache_miss += 1;
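
Chunking no longer copies anything: a second llama_batch, batch_view, simply points into the original batch's arrays at offset i, so params.n_batch only bounds how many tokens a single llama_decode call sees. A sketch of that loop, assuming the struct layout at this commit; the halve-and-retry path the example takes on a soft failure is only noted in a comment:

```cpp
// Sketch of chunked decoding via a view into the shared batch (illustrative helper).
#include <algorithm>
#include "llama.h"

static int decode_in_chunks(llama_context * ctx, const llama_batch & batch,
                            int32_t n_batch, int n_threads) {
    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
        const uint32_t n_tokens = std::min(n_batch, (int32_t) batch.n_tokens - i);

        llama_batch batch_view = {
            n_tokens,
            batch.token  + i,
            nullptr,            // no embeddings, token ids are used
            batch.pos    + i,
            batch.seq_id + i,
            batch.logits + i,
            0, 0, 0,            // all_pos_0, all_pos_1, all_seq_id - unused with explicit arrays
        };

        const int ret = llama_decode(ctx, batch_view, n_threads);
        if (ret != 0) {
            // a positive ret means no KV-cache slot was found; the example reacts by
            // halving n_batch and retrying, which is omitted from this sketch
            return ret;
        }
    }

    return 0;
}
```
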
@@ -357,6 +354,8 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
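
The batch allocated once at the top of main is released with llama_batch_free right before the context and model, mirroring the init/free pairing of the other llama.cpp objects. A sketch of that teardown order (llama_backend_free is how the examples typically end the process; it is not shown in this hunk):

```cpp
// Sketch of the teardown order used by the example.
#include "llama.h"

static void teardown(llama_batch & batch, llama_context * ctx, llama_model * model) {
    llama_batch_free(batch);   // pairs with llama_batch_init()
    llama_free(ctx);           // releases the context and its KV cache
    llama_free_model(model);   // releases the model weights

    llama_backend_free();
}
```
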