refactor code + remove unused comments + improved README.md
This commit is contained in:
parent 9f72b44635
commit 7e64bfe060
2 changed files with 101 additions and 54 deletions
@@ -3,7 +3,7 @@
#include "build-info.h"
#include "grammar-parser.h"

#define SERVER_MULTIMODAL_SUPPORT
// #define SERVER_MULTIMODAL_SUPPORT

#ifdef SERVER_MULTIMODAL_SUPPORT
#include "../llava/clip.h"
@@ -44,11 +44,6 @@ struct server_params
int32_t write_timeout = 600;
};

// struct beam_search_callback_data {
// llama_server_context* ctx;
// llama_client_slot* slot;
// };

static bool server_verbose = false;

#if SERVER_VERBOSE != 1
@@ -660,6 +655,7 @@ struct llama_server_context
}
waitAllAreIdle();
all_slots_are_idle = true;

// wait until system prompt load
update_system_prompt = true;
while(update_system_prompt) {
@@ -672,7 +668,11 @@ struct llama_server_context
system_prompt = sys_props.value("prompt", "");
user_name = sys_props.value("anti_prompt", "");
assistant_name = sys_props.value("assistant_name", "");
notifySystemPromptChanged();
if(slots.size() > 0) {
notifySystemPromptChanged();
} else {
update_system_prompt = true;
}
}

void waitAllAreIdle() {
@@ -813,6 +813,7 @@ struct llama_server_context
});
return has_next_token; // continue
}

#ifdef SERVER_MULTIMODAL_SUPPORT
bool processImages(llama_client_slot &slot) {
for(slot_image &img : slot.images) {
@@ -1204,6 +1205,11 @@ struct llama_server_context
}
};

struct server_beam_search_callback_data {
llama_context * ctx;
llama_client_slot * slot;
};

static void server_print_usage(const char *argv0, const gpt_params &params,
const server_params &sparams)
{
@@ -1251,6 +1257,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
#ifdef SERVER_MULTIMODAL_SUPPORT
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
#endif
@@ -1258,7 +1266,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
}

static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params)
gpt_params &params, llama_server_context& llama)
{
gpt_params default_params;
server_params default_sparams;
@@ -1523,6 +1531,26 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string systm_content = "";
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.processSystemPromptData(json::parse(systm_content));
}
#ifdef SERVER_MULTIMODAL_SUPPORT
else if(arg == "--mmproj") {
@@ -1864,8 +1892,8 @@ static void log_server_request(const Request &req, const Response &res)
});
}

static bool is_at_eob(llama_server_context * server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context->ctx);
static bool is_at_eob(const server_beam_search_callback_data & server_context, const llama_token *tokens, const size_t n_tokens) {
return n_tokens && tokens[n_tokens - 1] == llama_token_eos(server_context.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
@@ -1875,34 +1903,34 @@ static bool is_at_eob(llama_server_context * server_context, const llama_token *
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.

// AVOID HEADACHES unnecessaries
// NO TESTED after PR #3589

// static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
// auto & llama = *static_cast<beam_search_callback_data*>(callback_data);
// // Mark beams as EOS as needed.
// for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
// llama_beam_view& beam_view = beams_state.beam_views[i];
// if (!beam_view.eob && is_at_eob(llama.ctx, beam_view.tokens, beam_view.n_tokens)) {
// beam_view.eob = true;
// }
// }
// printf(","); // Show progress
// if (const size_t n = beams_state.common_prefix_length) {
// llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
// assert(0u < beams_state.n_beams);
// const llama_token * tokens = beams_state.beam_views[0].tokens;
// const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
// std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
// printf("%zu", n);
// }
// fflush(stdout);
// #if 0 // DEBUG: print current beams for this iteration
// std::cout << "\n\nCurrent beams:\n";
// for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
// std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
// }
// #endif
// }
static void beam_search_callback(void *callback_data, llama_beams_state beams_state) {
auto & llama = *static_cast<server_beam_search_callback_data*>(callback_data);
// Mark beams as EOS as needed.
for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
llama_beam_view& beam_view = beams_state.beam_views[i];
if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
beam_view.eob = true;
}
}
printf(","); // Show progress
if (const size_t n = beams_state.common_prefix_length) {
llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
assert(0u < beams_state.n_beams);
const llama_token * tokens = beams_state.beam_views[0].tokens;
const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
printf("%zu", n);
}
fflush(stdout);
#if 0 // DEBUG: print current beams for this iteration
std::cout << "\n\nCurrent beams:\n";
for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
}
#endif
}

struct token_translator {
llama_context * ctx;
@@ -1933,7 +1961,7 @@ int main(int argc, char **argv)
// struct that contains llama context and inference
llama_server_context llama;

server_params_parse(argc, argv, sparams, params);
server_params_parse(argc, argv, sparams, params, llama);

if (params.model_alias == "unknown")
{
@@ -2015,8 +2043,6 @@ int main(int argc, char **argv)
llama.processSystemPromptData(data["system_prompt"]);
}

// llama_reset_timings(llama.ctx);

slot->reset();

parse_options_completion(data, slot, llama);
@@ -2030,14 +2056,14 @@ int main(int argc, char **argv)
if (!slot->params.stream) {
std::string completion_text = "";
if (llama.params.n_beams) {
// // Fill llama.generated_token_probs vector with final beam.
// beam_search_callback_data data_;
// data_.slot = slot;
// data_.ctx = &llama;
// llama_beam_search(llama.ctx, beam_search_callback, &data_, llama.params.n_beams,
// slot->n_past, llama.params.n_predict);
// // Translate llama.generated_token_probs to llama.generated_text.
// append_to_generated_text_from_generated_token_probs(llama, slot);
// Fill llama.generated_token_probs vector with final beam.
server_beam_search_callback_data data_beam;
data_beam.slot = slot;
data_beam.ctx = llama.ctx;
llama_beam_search(llama.ctx, beam_search_callback, &data_beam, llama.params.n_beams,
slot->n_past, llama.params.n_predict);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama, slot);
} else {
while (slot->isProcessing()) {
if(slot->hasNewToken()) {
@@ -2055,8 +2081,6 @@ int main(int argc, char **argv)
}

const json data = format_final_response(llama, slot, completion_text, probs);

//llama_print_timings(llama.ctx);
slot->release();
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
@@ -2138,8 +2162,6 @@ int main(int argc, char **argv)
llama.processSystemPromptData(data["system_prompt"]);
}

// llama_reset_timings(llama.ctx);

slot->reset();
slot->infill = true;
@@ -2167,7 +2189,6 @@ int main(int argc, char **argv)
}

const json data = format_final_response(llama, slot, completion_text, probs);
//llama_print_timings(llama.ctx);
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
} else {