diff --git a/examples/server-parallel/server.cpp b/examples/server-parallel/server.cpp
index 72ec7c2bc..dbc361fd3 100644
--- a/examples/server-parallel/server.cpp
+++ b/examples/server-parallel/server.cpp
@@ -80,36 +80,33 @@ struct llama_client_slot
     int32_t n_decoded = 0;
     int32_t i_batch   = -1;
     string prompt = "";
-    string sampled_token_str;
     string generated_text = "";
+    int n_tokens_predicted = 0;
     llama_token sampled;
+    std::vector<std::string> sampled_tokens;
     std::vector<llama_token> tokens_prev;
     slot_state state = IDLE;
     slot_command command = NONE;
-    bool newToken = false;
     float temperature = 0.1f;
 
     void start(string prompt_, float temp_) {
         prompt = prompt_;
         command = LOAD_PROMPT;
         temperature = temp_;
-        newToken = false;
+        LOG_TEE("slot %i is processing\n", id);
     }
 
     bool hasNewToken() {
-        if(newToken) {
-            newToken = false;
-            return true;
-        }
-        return false;
+        return sampled_tokens.size() > 0;
     }
 
     bool available() {
         return state == IDLE && command == NONE;
     }
 
-    void nofity() {
-        newToken = !newToken;
+    void addTokenString(string token) {
+        sampled_tokens.insert(sampled_tokens.begin(), token);
+        n_tokens_predicted++;
     }
 
     void release() {
@@ -163,7 +160,7 @@ struct server_parallel_context {
             slot.id = i;
             slot.prompt = "default";
             slot.state = IDLE;
-            slot.tokens_prev.resize(std::max(256, params.n_predict));
+            slot.tokens_prev.resize(params.n_predict);
             std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
             LOG_TEE(" - slot %i\n", slot.id);
             slots.push_back(slot);
@@ -247,7 +244,6 @@ struct server_parallel_context {
             if ((slot_id == -1 && slot.available()) || slot.id == slot_id)
             {
                 slot.start(prompt, temperature);
-                LOG_TEE("slot %i is processing\n", slot.id);
                 return &slot; // return a pointer to slot (thread safe?)
             }
         }
@@ -302,6 +298,7 @@ struct server_parallel_context {
         }
 
         batch.n_tokens = 0;
+        int kv_cache_free = (n_ctx - n_tokens_system);
 
         // decode any currently ongoing sequences
         for (auto & slot : slots) {
@@ -311,13 +308,17 @@ struct server_parallel_context {
                 llama_kv_cache_seq_rm(ctx, slot.id, n_tokens_system, n_ctx);
                 slot.state = IDLE;
                 slot.command = NONE;
+                slot.n_prompt = 0;
+                slot.n_tokens_predicted = 0;
                 continue;
            }
 
+            kv_cache_free -= slot.n_prompt;
+
             // no decode wait until the token had been send to client
             // improves performance and avoid decoherence?
-            if (slot.state == IDLE || slot.newToken) {
+            if (slot.state == IDLE) {
                 continue;
             }
 
@@ -339,12 +340,13 @@ struct server_parallel_context {
             if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
                 slot.state = PROCESSING;
                 slot.command = NONE;
-                //LOG_TEE("slot %i process prompt:\n%s%s'------------------------------\n", slot.id, system_prompt.c_str(), slot.prompt.c_str());
+                std::fill(slot.tokens_prev.begin(), slot.tokens_prev.end(), 0);
 
                 // do not prepend BOS because we have a system prompt!
                 std::vector<llama_token> tokens_prompt;
                 tokens_prompt = ::llama_tokenize(ctx, slot.prompt, false);
+                slot.n_tokens_predicted = 0;
 
                 for (size_t i = 0; i < tokens_prompt.size(); ++i) {
                     batch.token [batch.n_tokens] = tokens_prompt[i];
@@ -362,11 +364,6 @@ struct server_parallel_context {
                 slot.n_prompt  = tokens_prompt.size();
                 slot.n_decoded = 0;
                 slot.i_batch   = batch.n_tokens - 1;
-
-                // insert new requests one-by-one
-                //if (cont_batching) {
-                //    break;
-                //}
             }
         }
     }
@@ -379,13 +376,6 @@ struct server_parallel_context {
 
         int32_t n_batch = params.n_batch;
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            // experiment: process in powers of 2
-            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
-            //    n_batch /= 2;
-            //    i -= n_batch;
-            //    continue;
-            //}
-
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
 
             llama_batch batch_view = {
@@ -431,18 +421,17 @@ struct server_parallel_context {
                 slot.sampled = id;
 
                 size_t stop_pos =
-                    findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
+                   findStoppingStrings(slot.generated_text, token_str.size(), STOP_FULL);
 
-                slot.sampled_token_str = token_str;
-                // notify new token
-                slot.nofity();
+                slot.addTokenString(token_str);
+
+                kv_cache_free -= slot.n_tokens_predicted;
 
                 if (slot.n_decoded > 2 &&
-                        (id == llama_token_eos(ctx) ||
-                         (params.n_predict > 0 &&
-                          slot.n_decoded + slot.n_prompt >=
-                          params.n_predict) ||
-                         stop_pos != std::string::npos)) {
+                    (id == llama_token_eos(ctx) ||
+                    (slot.n_decoded + slot.n_prompt >=
+                     params.n_predict) ||
+                     stop_pos != std::string::npos)) {
                     //LOG_TEE("slot %i generated text:\n%s'------------------------------\n", slot.id, slot.generated_text.c_str());
                     slot.generated_text.clear();
                     slot.release();
@@ -450,6 +439,11 @@
                 slot.i_batch = -1;
             }
         }
+
+        if(kv_cache_free < 0) {
+            LOG_TEE("\nError: kv cache is full, increase context size.");
+            return false;
+        }
         return true;
     }
 };
@@ -759,6 +753,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
+            if(params.n_predict <= 128) { // this example doesn't support long prompts
+                params.n_predict = 128;
+            }
         }
         else if (arg == "-r" || arg == "--reverse-prompt") {
             if (++i >= argc)
@@ -858,7 +855,8 @@ int main(int argc, char **argv)
                 }
                 if(slot->hasNewToken()) { // new token notification
                     stringstream ss;
-                    json res_d = {{ "content", slot->sampled_token_str }};
+                    json res_d = {{ "content", slot->sampled_tokens.back() }};
+                    slot->sampled_tokens.pop_back();
                     ss << "data: " << res_d.dump() << "\n\n";
                     string result = ss.str();
                     if(!sink.write(result.c_str(), result.size())) {
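
Note on the new streaming path: `addTokenString()` inserts each sampled piece at the front of `sampled_tokens`, and the HTTP handler in `main()` drains from the back with `back()`/`pop_back()`, so the per-slot buffer behaves as a FIFO and tokens reach the client in generation order even when the writer lags behind the decode loop. Below is a minimal standalone sketch of that queue discipline; the `Slot` struct and `main()` here are illustrative stand-ins for `llama_client_slot` and the SSE handler, not code from this patch.

```cpp
// Standalone illustration of the sampled_tokens FIFO introduced by the patch:
// the decode loop inserts at the front, the streaming handler pops from the
// back, so tokens come out in the order they were generated.
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

struct Slot {                              // hypothetical stand-in for llama_client_slot
    std::vector<std::string> sampled_tokens;
    int n_tokens_predicted = 0;

    void addTokenString(const std::string & token) {
        sampled_tokens.insert(sampled_tokens.begin(), token);
        n_tokens_predicted++;
    }

    bool hasNewToken() const {
        return !sampled_tokens.empty();
    }
};

int main() {
    Slot slot;

    // producer side: decode loop appends tokens as they are sampled
    slot.addTokenString("Hello");
    slot.addTokenString(",");
    slot.addTokenString(" world");

    // consumer side: streaming handler drains the queue oldest-first
    while (slot.hasNewToken()) {
        std::cout << slot.sampled_tokens.back();   // oldest pending token
        slot.sampled_tokens.pop_back();
    }
    std::cout << "\n";                             // prints "Hello, world"

    assert(slot.n_tokens_predicted == 3);
    return 0;
}
```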