refactor code + remove unused comments + improve README.md

FSSRepo 2023-10-14 00:31:34 -04:00
parent 9f72b44635
commit 7e64bfe060
2 changed files with 101 additions and 54 deletions


@@ -3,7 +3,7 @@
#include "build-info.h"
#include "grammar-parser.h"
#define SERVER_MULTIMODAL_SUPPORT
// #define SERVER_MULTIMODAL_SUPPORT
#ifdef SERVER_MULTIMODAL_SUPPORT
#include "../llava/clip.h"
@@ -44,11 +44,6 @@ struct server_params
int32_t write_timeout = 600;
};
// struct beam_search_callback_data {
// llama_server_context* ctx;
// llama_client_slot* slot;
// };
static bool server_verbose = false;
#if SERVER_VERBOSE != 1
@@ -660,6 +655,7 @@ struct llama_server_context
}
waitAllAreIdle();
all_slots_are_idle = true;
// wait until the system prompt has loaded
update_system_prompt = true;
while(update_system_prompt) {
@@ -672,7 +668,11 @@ struct llama_server_context
system_prompt = sys_props.value("prompt", "");
user_name = sys_props.value("anti_prompt", "");
assistant_name = sys_props.value("assistant_name", "");
notifySystemPromptChanged();
if(slots.size() > 0) {
notifySystemPromptChanged();
} else {
update_system_prompt = true;
}
}
void waitAllAreIdle() {
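For context, a minimal sketch of the JSON object that processSystemPromptData consumes in the hunk above, assuming nlohmann::json as used elsewhere in this file; the values are illustrative, not from the commit:

    // Hedged sketch: the three keys read by processSystemPromptData,
    // with illustrative placeholder values.
    json sys_props = {
        {"prompt",         "You are a helpful assistant."}, // -> system_prompt
        {"anti_prompt",    "User:"},                        // -> user_name
        {"assistant_name", "Assistant:"}                    // -> assistant_name
    };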
@@ -813,6 +813,7 @@ struct llama_server_context
});
return has_next_token; // continue
}
#ifdef SERVER_MULTIMODAL_SUPPORT
bool processImages(llama_client_slot &slot) {
for(slot_image &img : slot.images) {
@@ -1204,6 +1205,11 @@ struct llama_server_context
}
};
struct server_beam_search_callback_data {
llama_context * ctx;
llama_client_slot * slot;
};
static void server_print_usage(const char *argv0, const gpt_params &params,
const server_params &sparams)
{
@@ -1251,6 +1257,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
#ifdef SERVER_MULTIMODAL_SUPPORT
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
#endif
@@ -1258,7 +1266,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
}
static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params)
gpt_params &params, llama_server_context& llama)
{
gpt_params default_params;
server_params default_sparams;
@@ -1523,6 +1531,26 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string system_content;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.processSystemPromptData(json::parse(system_content));
}
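A hedged usage sketch for the new flag (model path and file name are illustrative): save the system-prompt JSON to a file, e.g. system.json containing

    {"prompt": "You are a helpful assistant.", "anti_prompt": "User:", "assistant_name": "Assistant:"}

then start the server with

    ./server -m models/model.gguf -spf system.json

Note that the file contents are parsed with json::parse and passed directly to processSystemPromptData, so the file holds the sys_props object itself rather than a wrapper with a "system_prompt" key as used by the HTTP endpoints.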
#ifdef SERVER_MULTIMODAL_SUPPORT
else if(arg == "--mmproj") {
@@ -1864,8 +1892,8 @@ static void log_server_request(const Request &req, const Response &res)
});
}
static bool is_at_eob(llama_server_context * server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context->ctx);
static bool is_at_eob(const server_beam_search_callback_data & server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens - 1] == llama_token_eos(server_context.ctx);
}
// Function matching type llama_beam_search_callback_fn_t.
@@ -1875,34 +1903,34 @@ static bool is_at_eob(llama_server_context * server_context, const llama_token *
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
// avoid unnecessary headaches
// NOT TESTED after PR #3589
// static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
// auto & llama = *static_cast<beam_search_callback_data*>(callback_data);
// // Mark beams as EOS as needed.
// for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
// llama_beam_view& beam_view = beams_state.beam_views[i];
// if (!beam_view.eob && is_at_eob(llama.ctx, beam_view.tokens, beam_view.n_tokens)) {
// beam_view.eob = true;
// }
// }
// printf(","); // Show progress
// if (const size_t n = beams_state.common_prefix_length) {
// llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
// assert(0u < beams_state.n_beams);
// const llama_token * tokens = beams_state.beam_views[0].tokens;
// const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
// std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
// printf("%zu", n);
// }
// fflush(stdout);
// #if 0 // DEBUG: print current beams for this iteration
// std::cout << "\n\nCurrent beams:\n";
// for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
// std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
// }
// #endif
// }
static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
auto & llama = *static_cast<server_beam_search_callback_data*>(callback_data);
// Mark beams as EOS as needed.
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
llama_beam_view& beam_view = beams_state.beam_views[i];
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
beam_view.eob = true;
}
}
printf(","); // Show progress
if (const size_t n = beams_state.common_prefix_length) {
llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens;
const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
printf("%zu", n);
}
fflush(stdout);
#if 0 // DEBUG: print current beams for this iteration
std::cout << "\n\nCurrent beams:\n";
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
}
#endif
}
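For reference, a minimal sketch of how this callback gets driven; it mirrors the llama_beam_search invocation added later in this commit, so the names and arguments come from the diff itself:

    // Hedged sketch: wiring beam_search_callback into llama_beam_search,
    // mirroring the call added further down in this commit.
    server_beam_search_callback_data data_beam;
    data_beam.ctx  = llama.ctx; // raw llama_context, per the new struct
    data_beam.slot = slot;      // slot whose generated_token_probs are filled
    llama_beam_search(llama.ctx, beam_search_callback, &data_beam,
                      llama.params.n_beams, slot->n_past, llama.params.n_predict);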
struct token_translator {
llama_context * ctx;
@@ -1933,7 +1961,7 @@ int main(int argc, char **argv)
// struct that holds the llama context and inference state
llama_server_context llama;
server_params_parse(argc, argv, sparams, params);
server_params_parse(argc, argv, sparams, params, llama);
if (params.model_alias == "unknown")
{
@@ -2015,8 +2043,6 @@ int main(int argc, char **argv)
llama.processSystemPromptData(data["system_prompt"]);
}
// llama_reset_timings(llama.ctx);
slot->reset();
parse_options_completion(data, slot, llama);
@@ -2030,14 +2056,14 @@ int main(int argc, char **argv)
if (!slot->params.stream) {
std::string completion_text = "";
if (llama.params.n_beams) {
// // Fill llama.generated_token_probs vector with final beam.
// beam_search_callback_data data_;
// data_.slot = slot;
// data_.ctx = &llama;
// llama_beam_search(llama.ctx, beam_search_callback, &data_, llama.params.n_beams,
// slot->n_past, llama.params.n_predict);
// // Translate llama.generated_token_probs to llama.generated_text.
// append_to_generated_text_from_generated_token_probs(llama, slot);
// Fill llama.generated_token_probs vector with final beam.
server_beam_search_callback_data data_beam;
data_beam.slot = slot;
data_beam.ctx = llama.ctx;
llama_beam_search(llama.ctx, beam_search_callback, &data_beam, llama.params.n_beams,
slot->n_past, llama.params.n_predict);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama, slot);
} else {
while (slot->isProcessing()) {
if(slot->hasNewToken()) {
@@ -2055,8 +2081,6 @@ int main(int argc, char **argv)
}
const json data = format_final_response(llama, slot, completion_text, probs);
//llama_print_timings(llama.ctx);
slot->release();
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
@@ -2138,8 +2162,6 @@ int main(int argc, char **argv)
llama.processSystemPromptData(data["system_prompt"]);
}
// llama_reset_timings(llama.ctx);
slot->reset();
slot->infill = true;
@@ -2167,7 +2189,6 @@ int main(int argc, char **argv)
}
const json data = format_final_response(llama, slot, completion_text, probs);
//llama_print_timings(llama.ctx);
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
} else {