llama : extend batch API to select which logits to output

commit fa0e677820
parent 897caccdf4

4 changed files with 46 additions and 6 deletions
@@ -79,7 +79,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, };
+        llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
         if (llama_decode(ctx, batch, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
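The API change itself is visible in the initializer pair above: a new pointer is inserted between the seq_id field and the trailing scalars, and passing nullptr keeps the previous behavior. A sketch of the extended struct as implied by this diff; the field names and the meaning of the trailing scalars are inferred from the initializers, not confirmed API:

// Sketch of the extended llama_batch implied by the initializers above.
// Field names and the trailing scalars are inferred from this diff.
typedef struct llama_batch {
    uint32_t       n_tokens;

    llama_token  * token;   // token ids, or nullptr when embeddings are used
    float        * embd;    // input embeddings, or nullptr when tokens are used
    llama_pos    * pos;     // per-token positions
    llama_seq_id * seq_id;  // per-token sequence ids
    int8_t       * logits;  // new: per-token flag - output logits for this
                            // token? nullptr preserves the old behavior

    llama_pos      all_pos_0;  // used if pos    == nullptr (start position)
    llama_pos      all_pos_1;  // used if pos    == nullptr (position step)
    llama_seq_id   all_seq_id; // used if seq_id == nullptr
} llama_batch;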
@@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
 
     const int n_clients = 4;
 
+    // insert new requests as soon as the previous one is done
+    const bool hot_swap = true;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
     LOG_TEE("Log start\n");
@@ -121,14 +124,23 @@ int main(int argc, char ** argv) {
     std::vector<llama_token>  batch_token;
     std::vector<llama_pos>    batch_pos;
     std::vector<llama_seq_id> batch_seq_id;
+    std::vector<int8_t>       batch_logits;
     std::vector<client *>     batch_clients;
 
     while (true) {
-        uint32_t n_tokens = 0;
+        int32_t n_total_prompt = 0;
+        int32_t n_total_gen    = 0;
+
+        float t_avg = 0.0f;
+
+        const int32_t n_seq = 128;
+
+        while (g_seq_id < n_seq + n_clients) {
 
             batch_token.clear();
             batch_pos.clear();
             batch_seq_id.clear();
+            batch_logits.clear();
 
             for (auto & client : clients) {
                 if (client.seq_id == -1) {
@@ -138,6 +150,7 @@ int main(int argc, char ** argv) {
                 batch_token.push_back(client.sampled);
                 batch_pos.push_back(client.n_decoded);
                 batch_seq_id.push_back(client.seq_id);
+                batch_logits.push_back(true);
                 batch_clients.push_back(&client);
                 client.n_decoded += 1;
                 client.i_batch = batch_token.size() - 1;
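In the generation phase every active client contributes exactly one token per decode and always needs its logits to sample the next token, hence the unconditional push_back(true). A self-contained sketch of that pattern, with a simplified stand-in for the example's client struct (names are illustrative):

#include <cstdint>
#include <vector>
#include "llama.h"

// Simplified stand-in for the client struct used by the parallel example.
struct client_t {
    llama_seq_id seq_id    = -1; // -1 means the slot is free
    llama_token  sampled   = 0;  // last token sampled for this client
    int32_t      n_decoded = 0;  // prompt + generated tokens so far
    int32_t      i_batch   = -1; // index of this client's logits in the batch
};

// Queue one generated token per active client; logits are always requested
// because each client samples its next token from them.
static void push_gen_tokens(std::vector<client_t>     & clients,
                            std::vector<llama_token>  & batch_token,
                            std::vector<llama_pos>    & batch_pos,
                            std::vector<llama_seq_id> & batch_seq_id,
                            std::vector<int8_t>       & batch_logits) {
    for (auto & client : clients) {
        if (client.seq_id == -1) {
            continue; // free slot - nothing to decode
        }

        batch_token .push_back(client.sampled);
        batch_pos   .push_back(client.n_decoded);
        batch_seq_id.push_back(client.seq_id);
        batch_logits.push_back(true);

        client.n_decoded += 1;
        client.i_batch    = (int32_t) batch_token.size() - 1;
    }
}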
@@ -146,7 +159,9 @@ int main(int argc, char ** argv) {
            if (batch_token.empty()) {
                // all sequences have ended - clear the entire KV cache
                llama_kv_cache_rm_tokens(ctx, -1, -1);
            }
+
+           if (hot_swap || batch_token.empty()) {
                for (auto & client : clients) {
                    if (client.seq_id == -1) {
                        client.seq_id = g_seq_id;
@@ -166,7 +181,10 @@ int main(int argc, char ** argv) {
                        batch_pos.push_back(i);
                        batch_seq_id.push_back(client.seq_id);
                        batch_clients.push_back(&client);
+                       batch_logits.push_back(false);
                    }
+                   batch_logits.back() = true;
+
                    client.n_prompt  = prompt_tokens.size();
                    client.n_decoded = prompt_tokens.size();
                    client.i_batch   = batch_token.size() - 1;
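During prompt processing the intermediate logits are never sampled from, so the flag is pushed as false for every prompt position and then flipped to true for the last one only: that position yields the distribution the first response token is sampled from. A consolidated sketch of the pattern (the batch_token push sits outside the hunk above; the helper and its names are illustrative):

// Queue a whole prompt for one sequence, requesting logits only at the last
// position - the only one the example ever samples from.
static void push_prompt(const std::vector<llama_token> & prompt_tokens,
                        llama_seq_id                     seq_id,
                        std::vector<llama_token>  & batch_token,
                        std::vector<llama_pos>    & batch_pos,
                        std::vector<llama_seq_id> & batch_seq_id,
                        std::vector<int8_t>       & batch_logits) {
    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
        batch_token .push_back(prompt_tokens[i]);
        batch_pos   .push_back((llama_pos) i);
        batch_seq_id.push_back(seq_id);
        batch_logits.push_back(false); // intermediate prompt logits are unused
    }
    if (!prompt_tokens.empty()) {
        batch_logits.back() = true;    // sample the first response token here
    }
}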
@@ -186,6 +204,7 @@ int main(int argc, char ** argv) {
                    nullptr,
                    batch_pos.data()    + i,
                    batch_seq_id.data() + i,
+                   batch_logits.data() + i,
                    0, 0, 0, // unused
                };
 
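The flat arrays are decoded in chunks: a llama_batch is brace-initialized as a view into the vectors at offset i, and the trailing all_pos_0/all_pos_1/all_seq_id scalars stay zero because explicit pos and seq_id arrays are supplied. A sketch under the field layout inferred earlier (the chunk-length parameter is an assumption):

// Build a llama_batch that views one chunk [i, i + n_tokens) of the flat
// arrays. Field order follows the initializers in this diff.
static llama_batch make_batch_view(std::vector<llama_token>  & batch_token,
                                   std::vector<llama_pos>    & batch_pos,
                                   std::vector<llama_seq_id> & batch_seq_id,
                                   std::vector<int8_t>       & batch_logits,
                                   size_t i, uint32_t n_tokens) {
    return {
        n_tokens,
        batch_token.data()  + i,
        nullptr,                 // token ids are used, not embeddings
        batch_pos.data()    + i,
        batch_seq_id.data() + i,
        batch_logits.data() + i, // per-token logits selection
        0, 0, 0,                 // all_pos_0, all_pos_1, all_seq_id - unused
    };
}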
@@ -232,14 +251,20 @@ int main(int argc, char ** argv) {
 
                    const auto t_main_end = ggml_time_us();
 
-                   printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
+                   printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
                            client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
+                           (t_main_end - client.t_start_prompt) / 1e6,
                            (double) (client.n_prompt                   ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
                            (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
                            (double) (client.n_decoded                  ) / (t_main_end - client.t_start_prompt) * 1e6,
                            ::trim(client.input).c_str(),
                            ::trim(client.response).c_str());
 
+                   n_total_prompt += client.n_prompt;
+                   n_total_gen    += client.n_decoded - client.n_prompt;
+
+                   t_avg += (t_main_end - client.t_start_prompt) / 1e6;
+
                    client.seq_id = -1;
                }
 
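The timing arithmetic: ggml_time_us() returns microseconds, so a token count divided by a microsecond interval is scaled by 1e6 to get tokens per second, and an interval divided by 1e6 gives seconds. For example, 128 prompt tokens processed in 250000 us give PP = 128 / 250000 * 1e6 = 512 t/s. The same rule as a tiny helper (name is illustrative):

// tokens per second over an interval measured in microseconds
static double tokens_per_sec(int n_tokens, int64_t t0_us, int64_t t1_us) {
    return (double) n_tokens / (double) (t1_us - t0_us) * 1e6;
}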
@@ -248,6 +273,11 @@ int main(int argc, char ** argv) {
            }
        }
 
        LOG_TEE("\n\n");
+       LOG_TEE("Total prompt tokens: %d\n", n_total_prompt);
+       LOG_TEE("Total gen tokens:    %d\n", n_total_gen);
+       LOG_TEE("Avg time per seq:    %.2f s\n", t_avg / n_seq);
+
+       LOG_TEE("\n\n");
 
        llama_print_timings(ctx);
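Putting the pieces together, a minimal single-sequence fragment using the extended API, reusing the hypothetical helpers sketched above and assuming the example's surrounding context (ctx, params, prompt_tokens). How the selected logits are read back is not shown in this diff; the client.i_batch bookkeeping suggests they are indexed by position within the batch:

std::vector<llama_token>  batch_token;
std::vector<llama_pos>    batch_pos;
std::vector<llama_seq_id> batch_seq_id;
std::vector<int8_t>       batch_logits;

// queue the prompt, requesting logits only at its last position
push_prompt(prompt_tokens, /*seq_id =*/ 0,
            batch_token, batch_pos, batch_seq_id, batch_logits);

// decode the whole thing in one chunk
llama_batch batch = make_batch_view(batch_token, batch_pos, batch_seq_id, batch_logits,
                                    0, (uint32_t) batch_token.size());

if (llama_decode(ctx, batch, params.n_threads)) {
    fprintf(stderr, "%s : failed to eval\n", __func__);
}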