Merge branch 'master' into server-rev

commit 176993c871

46 changed files with 583 additions and 4691 deletions
@@ -181,8 +181,6 @@ struct slot_params
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
 
-    std::string grammar; // optional BNF-like grammar to constrain sampling
-
     std::vector<std::string> antiprompt;
 
     json input_prefix;
@@ -339,36 +337,6 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }
 
-// TODO: this is not needed, should reuse llama_sampling_init from common/sampling.h
-static struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params &sparams, const std::string &grammar, int n_ctx)
-{
-    struct llama_sampling_context * result = new llama_sampling_context();
-
-    result->params  = sparams;
-    result->grammar = nullptr;
-
-    // if there is a grammar, parse it
-    if (!grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
-
-        // will be empty (default) if there are parse errors
-        if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-            return nullptr;
-        }
-
-        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
-        result->grammar = llama_grammar_init(
-            grammar_rules.data(),
-            grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-    }
-
-    result->prev.resize(n_ctx);
-
-    return result;
-}
-
 struct llama_client_slot
 {
     int id;
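The deleted `llama_sampling_init_srv` helper duplicated `llama_sampling_init` from common/sampling.h, exactly as its own TODO noted; the server now calls the common initializer directly (see the hunks below), with the grammar text and history size carried on `llama_sampling_params` instead of extra arguments. A minimal sketch of the replacement path, assuming the post-merge common API; `make_slot_sampler` is a hypothetical wrapper, not part of this patch:

    #include "common/sampling.h"

    // Hypothetical wrapper illustrating the replacement call. The grammar is
    // parsed inside llama_sampling_init; on a parse error it returns nullptr,
    // which is why the slot code below checks `ctx_sampling != nullptr`.
    static llama_sampling_context * make_slot_sampler(const std::string & grammar, int n_ctx) {
        llama_sampling_params sparams;
        sparams.grammar = grammar; // replaces the separate `grammar` argument
        sparams.n_prev  = n_ctx;   // replaces the old result->prev.resize(n_ctx)
        return llama_sampling_init(sparams);
    }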
@@ -410,7 +378,6 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context* ctx_sampling = nullptr;
     bool has_next_token = true;
-    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -435,7 +402,7 @@ struct llama_client_slot
             llama_sampling_free(ctx_sampling);
         }
 
-        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        ctx_sampling = llama_sampling_init(sparams);
 
         for (slot_image &img : images)
         {
@@ -455,7 +422,7 @@ struct llama_client_slot
             llama_sampling_free(ctx_sampling);
         }
 
-        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        ctx_sampling = llama_sampling_init(sparams);
         return ctx_sampling != nullptr;
     }
 
@@ -629,7 +596,7 @@ struct llama_server_context
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.max_context_size = max_ctx_per_slot;
+            slot.sparams.n_prev = max_ctx_per_slot;
             slot.reset();
 
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
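Dropping the slot's `max_context_size` field works because the common sampling context now owns its history size: `n_prev` tells `llama_sampling_init` how many recently accepted tokens to retain, and the repetition penalties read from that window. A hedged sketch of the relationship, assuming the common API this merge migrates to; the literal values are illustrative only:

    // The sampler keeps a rolling window of the last n_prev accepted tokens;
    // penalty_last_n (set from "repeat_last_n" below) should not exceed it.
    llama_sampling_params sparams;
    sparams.n_prev         = 4096; // stand-in for max_ctx_per_slot, as in the hunk above
    sparams.penalty_last_n = 64;   // how far back the repetition penalties look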
@@ -702,26 +669,26 @@ struct llama_server_context
     bool launch_slot_with_data(llama_client_slot* &slot, json data) {
         slot_params default_params;
         llama_sampling_params default_sparams;
-        slot->params.stream = json_value(data, "stream", false);
-        slot->params.cache_prompt = json_value(data, "cache_prompt", false);
-        slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
-        slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
-        slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
-        slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
-        slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
-        slot->sparams.repeat_last_n = json_value(data, "repeat_last_n", default_sparams.repeat_last_n);
-        slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
-        slot->sparams.repeat_penalty = json_value(data, "repeat_penalty", default_sparams.repeat_penalty);
-        slot->sparams.presence_penalty = json_value(data, "presence_penalty", default_sparams.presence_penalty);
-        slot->sparams.frequency_penalty = json_value(data, "frequency_penalty", default_sparams.frequency_penalty);
-        slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
-        slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
-        slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
-        slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
-        slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
-        slot->params.seed = json_value(data, "seed", default_params.seed);
-        slot->params.grammar = json_value(data, "grammar", default_params.grammar);
-        slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+        slot->params.stream = json_value(data, "stream", false);
+        slot->params.cache_prompt = json_value(data, "cache_prompt", false);
+        slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
+        slot->sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
+        slot->sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
+        slot->sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
+        slot->sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+        slot->sparams.temp = json_value(data, "temperature", default_sparams.temp);
+        slot->sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
+        slot->sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
+        slot->sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
+        slot->sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+        slot->sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
+        slot->sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
+        slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
+        slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
+        slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
+        slot->params.seed = json_value(data, "seed", default_params.seed);
+        slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+        slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
 
         // infill
         if (data.count("input_prefix") != 0)
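Note the request wire format is unchanged on both sides of this hunk: the JSON keys are identical, only the C++ members of `llama_sampling_params` were renamed (`repeat_last_n` -> `penalty_last_n`, `repeat_penalty` -> `penalty_repeat`, `frequency_penalty` -> `penalty_freq`, `presence_penalty` -> `penalty_present`), and `grammar` moved from `slot->params` to `slot->sparams`. A short sketch of the stable mapping, reusing the file's own `json_value` helper; the literal values are illustrative only:

    json data = json::parse(R"({ "repeat_penalty": 1.1, "grammar": "root ::= [0-9]+" })");

    llama_sampling_params sparams;
    // same JSON keys as before the merge, new member names:
    sparams.penalty_repeat = json_value(data, "repeat_penalty", sparams.penalty_repeat);
    sparams.grammar        = json_value(data, "grammar",        sparams.grammar);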
@@ -1157,10 +1124,10 @@ struct llama_server_context
             {"top_p", slot.sparams.top_p},
             {"tfs_z", slot.sparams.tfs_z},
             {"typical_p", slot.sparams.typical_p},
-            {"repeat_last_n", slot.sparams.repeat_last_n},
-            {"repeat_penalty", slot.sparams.repeat_penalty},
-            {"presence_penalty", slot.sparams.presence_penalty},
-            {"frequency_penalty", slot.sparams.frequency_penalty},
+            {"repeat_last_n", slot.sparams.penalty_last_n},
+            {"repeat_penalty", slot.sparams.penalty_repeat},
+            {"presence_penalty", slot.sparams.penalty_present},
+            {"frequency_penalty", slot.sparams.penalty_freq},
             {"mirostat", slot.sparams.mirostat},
             {"mirostat_tau", slot.sparams.mirostat_tau},
             {"mirostat_eta", slot.sparams.mirostat_eta},
@@ -1172,7 +1139,7 @@ struct llama_server_context
             {"stream", slot.params.stream},
             {"logit_bias", slot.sparams.logit_bias},
             {"n_probs", slot.sparams.n_probs},
-            {"grammar", slot.params.grammar},
+            {"grammar", slot.sparams.grammar},
         };
     }
 
@@ -1707,7 +1674,7 @@ struct llama_server_context
         completion_token_output result;
         const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
 
-        llama_sampling_accept(slot.ctx_sampling, ctx, id);
+        llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 
         if (slot.n_decoded == 1)
        {
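The extra `true` matches the fourth parameter that `llama_sampling_accept` gained on master: with a grammar active, the accepted token must also advance the grammar state, not just the penalty history. A sketch of the signature as inferred from this call site; see common/sampling.h for the authoritative declaration:

    // void llama_sampling_accept(
    //         llama_sampling_context * ctx_sampling,
    //         llama_context          * ctx_main,
    //         llama_token              id,
    //         bool                     apply_grammar);
    llama_sampling_accept(slot.ctx_sampling, ctx, id, /*apply_grammar=*/true);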