llama : improve llama_batch API + simplify parallel example
This commit is contained in:
parent
a1327c71c6
commit
addae65fd4
6 changed files with 111 additions and 70 deletions
|
@ -127,11 +127,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
|
||||
std::vector<llama_token> batch_token;
|
||||
std::vector<llama_pos> batch_pos;
|
||||
std::vector<llama_seq_id> batch_seq_id;
|
||||
std::vector<int8_t> batch_logits;
|
||||
std::vector<client *> batch_clients;
|
||||
llama_batch batch = llama_batch_init(params.n_batch, 0);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
|
@ -146,24 +142,15 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
batch_pos.clear();
|
||||
batch_seq_id.clear();
|
||||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (size_t i = 0; i < n_tokens_system; ++i) {
|
||||
batch_pos.push_back(i);
|
||||
batch_seq_id.push_back(0);
|
||||
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
llama_batch batch = {
|
||||
n_tokens_system,
|
||||
tokens_system.data(),
|
||||
nullptr,
|
||||
batch_pos.data(),
|
||||
batch_seq_id.data(),
|
||||
nullptr,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
|
@ -180,63 +167,72 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("Processing requests ...\n\n");
|
||||
|
||||
while (true) {
|
||||
uint32_t n_tokens = 0;
|
||||
|
||||
batch_token.clear();
|
||||
batch_pos.clear();
|
||||
batch_seq_id.clear();
|
||||
batch_logits.clear();
|
||||
batch.n_tokens = 0;
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
batch_token.push_back(client.sampled);
|
||||
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
|
||||
batch_seq_id.push_back(client.id);
|
||||
batch_logits.push_back(true);
|
||||
batch_clients.push_back(&client);
|
||||
batch.token [batch.n_tokens] = client.sampled;
|
||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
if (batch_token.empty()) {
|
||||
if (batch.n_tokens == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 0; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (cont_batching || batch_token.empty()) {
|
||||
// insert new sequences for decoding
|
||||
if (cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||
client.seq_id = g_seq_id;
|
||||
|
||||
client.t_start_prompt = ggml_time_us();
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
batch_token.push_back(tokens_prompt[i]);
|
||||
batch_pos.push_back(i + n_tokens_system);
|
||||
batch_seq_id.push_back(client.id);
|
||||
batch_clients.push_back(&client);
|
||||
batch_logits.push_back(false);
|
||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = false;
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
if (batch.n_tokens > 0) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
}
|
||||
batch_logits.back() = true;
|
||||
|
||||
client.n_prompt = tokens_prompt.size();
|
||||
client.n_decoded = 0;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
client.i_batch = batch.n_tokens - 1;
|
||||
|
||||
LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
|
||||
|
||||
g_seq_id += 1;
|
||||
|
||||
// insert new requests one-by-one
|
||||
//if (cont_batching) {
|
||||
// break;
|
||||
//}
|
||||
|
@ -244,34 +240,35 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
if (batch_token.empty()) {
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = params.n_batch;
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
|
||||
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
|
||||
llama_batch batch = {
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch_token.data() + i,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch_pos.data() + i,
|
||||
batch_seq_id.data() + i,
|
||||
batch_logits.data() + i,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch, params.n_threads);
|
||||
const int ret = llama_decode(ctx, batch_view, params.n_threads);
|
||||
if (ret != 0) {
|
||||
if (n_batch == 1 || ret < 0) {
|
||||
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||
|
||||
n_cache_miss += 1;
|
||||
|
||||
|
@ -357,6 +354,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
|
|
|
@ -419,7 +419,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
}
|
||||
|
||||
static std::vector<float> hellaswag_evaluate_tokens(
|
||||
llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
||||
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
||||
) {
|
||||
std::vector<float> result;
|
||||
result.reserve(tokens.size() * n_vocab);
|
||||
|
|
|
@ -10,10 +10,12 @@ int main(int argc, char ** argv) {
|
|||
gpt_params params;
|
||||
|
||||
if (argc == 1 || argv[1][0] == '-') {
|
||||
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
|
||||
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
|
||||
return 1 ;
|
||||
}
|
||||
|
||||
int n_parallel = 1;
|
||||
|
||||
if (argc >= 2) {
|
||||
params.model = argv[1];
|
||||
}
|
||||
|
@ -22,6 +24,10 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = argv[2];
|
||||
}
|
||||
|
||||
if (argc >= 4) {
|
||||
n_parallel = std::atoi(argv[3]);
|
||||
}
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
params.prompt = "Hello my name is";
|
||||
}
|
||||
|
|
|
@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
while (true) {
|
||||
// sample from the target model
|
||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||
llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue