llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@ -470,8 +470,6 @@ node index.js
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
`penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
@ -724,7 +722,6 @@ Example:
"stopping_word": ""
},
"penalize_nl": true,
"penalty_prompt_tokens": [],
"presence_penalty": 0.0,
"prompt": "Say hello to llama.cpp",
"repeat_last_n": 64,
@ -748,8 +745,7 @@ Example:
"tfs_z": 1.0,
"top_k": 40,
"top_p": 0.949999988079071,
"typical_p": 1.0,
"use_penalty_prompt_tokens": false
"typical_p": 1.0
}
]
```

View file

@ -3,7 +3,6 @@
#include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@ -169,11 +168,13 @@ struct server_slot {
std::string stopping_word;
// sampling
llama_token sampled;
struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr;
json json_schema;
struct gpt_sampler_params sparams;
struct gpt_sampler * smpl = nullptr;
llama_token sampled;
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width
@ -651,8 +652,8 @@ struct server_context {
// Clear any sampling context
for (server_slot & slot : slots) {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
if (slot.smpl != nullptr) {
gpt_sampler_free(slot.smpl);
}
}
@ -883,8 +884,8 @@ struct server_context {
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
llama_sampling_params default_sparams = params.sparams;
auto & data = task.data;
auto default_sparams = params.sparams;
const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
slot.oaicompat = true;
@ -901,7 +902,7 @@ struct server_context {
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@ -923,7 +924,8 @@ struct server_context {
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
return false;
} else if (data.contains("json_schema") && !data.contains("grammar")) {
}
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
@ -973,56 +975,11 @@ struct server_context {
}
}
// penalize user-provided tokens
{
slot.sparams.penalty_prompt_tokens.clear();
slot.sparams.use_penalty_prompt_tokens = false;
const auto & penalty_prompt = data.find("penalty_prompt");
if (penalty_prompt != data.end()) {
if (penalty_prompt->is_string()) {
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
if (slot.params.n_predict > 0) {
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"id_slot", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
else if (penalty_prompt->is_array()) {
const auto n_tokens = penalty_prompt->size();
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
const int n_vocab = llama_n_vocab(model);
for (const auto & penalty_token : *penalty_prompt) {
if (penalty_token.is_number_integer()) {
const auto tok = penalty_token.get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.penalty_prompt_tokens.push_back(tok);
}
}
}
slot.sparams.use_penalty_prompt_tokens = true;
LOG_VERBOSE("penalty_prompt_tokens", {
{"id_slot", slot.id},
{"tokens", slot.sparams.penalty_prompt_tokens},
});
}
}
}
{
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
@ -1043,12 +1000,12 @@ struct server_context {
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
slot.sparams.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
slot.sparams.logit_bias.push_back({tok, bias});
}
}
}
@ -1070,26 +1027,27 @@ struct server_context {
}
{
const auto & samplers_sequence = data.find("samplers");
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
const auto & samplers = data.find("samplers");
if (samplers != data.end() && samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & sampler_name : *samplers_sequence) {
if (sampler_name.is_string()) {
sampler_names.emplace_back(sampler_name);
for (const auto & name : *samplers) {
if (name.is_string()) {
sampler_names.emplace_back(name);
}
}
slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
} else {
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
slot.sparams.samplers = default_sparams.samplers;
}
}
{
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
if (slot.smpl != nullptr) {
gpt_sampler_free(slot.smpl);
}
slot.ctx_sampling = llama_sampling_init(slot.sparams);
if (slot.ctx_sampling == nullptr) {
slot.smpl = gpt_sampler_init(model, slot.sparams);
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
@ -1178,11 +1136,6 @@ struct server_context {
slot.generated_text += token_str;
slot.has_next_token = true;
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
// we can change penalty_prompt_tokens because it is always created from scratch each request
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
}
// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@ -1300,13 +1253,10 @@ struct server_context {
}
json get_formated_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
std::vector<std::string> samplers_sequence;
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
std::vector<std::string> samplers;
samplers.reserve(slot.sparams.samplers.size());
for (const auto & sampler : slot.sparams.samplers) {
samplers.emplace_back(gpt_sampler_type_to_str(sampler));
}
return json {
@ -1321,13 +1271,11 @@ struct server_context {
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typical_p},
{"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
@ -1336,13 +1284,13 @@ struct server_context {
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"ignore_eos", ignore_eos},
{"ignore_eos", slot.sparams.ignore_eos},
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
//{"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence}
{"samplers", samplers},
};
}
@ -2136,7 +2084,7 @@ struct server_context {
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
llama_sampling_reset(slot.ctx_sampling);
gpt_sampler_reset(slot.smpl);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
@ -2149,7 +2097,7 @@ struct server_context {
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}
}
}
@ -2202,7 +2150,7 @@ struct server_context {
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.ctx_sampling);
gpt_sampler_reset(slot.smpl);
}
// remove the non-common part from the cache
@ -2375,18 +2323,18 @@ struct server_context {
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
} else {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
}
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
completion_token_output result;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
gpt_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
@ -2395,34 +2343,15 @@ struct server_context {
metrics.on_prompt_eval(slot);
}
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
if (n_probs > 0) {
const size_t n_valid = slot.ctx_sampling->n_valid;
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
// Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
}
if (slot.sparams.temp == 0.0f) {
// With greedy sampling the probabilities have possibly not been calculated.
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i == 0 ? 1.0f : 0.0f
});
}
} else {
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,
i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
});
}
}
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
if (!process_token(result, slot)) {