llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@ -33,6 +33,7 @@
static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
@ -92,7 +93,7 @@ static void write_logfile(
yaml_dump_string_multiline(logfile, "output", output.c_str());
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
llama_dump_timing_info_yaml(logfile, ctx);
llama_perf_dump_yaml(logfile, ctx);
fclose(logfile);
}
@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
} else {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx);
gpt_perf_print(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
llama_chat_msg new_msg{role, content};
auto formatted = llama_chat_format_single(
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG("formatted: %s\n", formatted.c_str());
return formatted;
@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
return 1;
}
llama_sampling_params & sparams = params.sparams;
auto & sparams = params.sparams;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log"));
@ -183,27 +183,23 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
print_build_info();
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
gpt_sampler * smpl = nullptr;
std::vector<llama_chat_msg> chat_msgs;
g_model = &model;
g_ctx = &ctx;
g_smpl = &smpl;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@ -211,10 +207,6 @@ int main(int argc, char ** argv) {
model = llama_init.model;
ctx = llama_init.context;
if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
@ -251,9 +243,6 @@ int main(int argc, char ** argv) {
}
llama_attach_threadpool(ctx, threadpool, threadpool_batch);
if (ctx_guidance) {
llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
}
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
@ -337,24 +326,6 @@ int main(int argc, char ** argv) {
}
// Tokenize negative prompt
std::vector<llama_token> guidance_inp;
int guidance_offset = 0;
int original_prompt_len = 0;
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
}
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
@ -421,15 +392,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
if (ctx_guidance) {
LOG_TEE("\n");
LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
}
}
if (params.n_keep > add_bos) {
LOG_TEE("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
@ -495,8 +457,15 @@ int main(int argc, char ** argv) {
}
}
}
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
smpl = gpt_sampler_init(model, sparams);
if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// group-attention state
@ -543,7 +512,6 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int n_past_guidance = 0;
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@ -555,7 +523,6 @@ int main(int argc, char ** argv) {
display = params.display_prompt;
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
// tokenized antiprompts
std::vector<std::vector<llama_token>> antiprompt_ids;
@ -565,12 +532,6 @@ int main(int argc, char ** argv) {
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
}
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
@ -612,7 +573,7 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
if (n_past + (int) embd.size() >= n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
@ -629,11 +590,7 @@ int main(int argc, char ** argv) {
n_past -= n_discard;
if (ctx_guidance) {
n_past_guidance -= n_discard;
}
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("after swap: n_past = %d\n", n_past);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
@ -686,46 +643,6 @@ int main(int argc, char ** argv) {
}
}
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
if (n_past_guidance < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
// * Replace the initial prompt
// * Shift everything by guidance_offset
embd_guidance = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) {
embd_guidance.insert(
embd_guidance.end(),
embd.begin() + original_prompt_len,
embd.end()
);
}
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
}
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
n_past_guidance += n_eval;
}
}
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@ -755,7 +672,6 @@ int main(int argc, char ** argv) {
}
embd.clear();
embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
@ -766,11 +682,11 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
gpt_sampler_accept(smpl, id, /* apply_grammar= */ true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
embd.push_back(id);
@ -789,7 +705,7 @@ int main(int argc, char ** argv) {
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
@ -832,7 +748,7 @@ int main(int argc, char ** argv) {
// check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) {
const int n_prev = 32;
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
@ -854,7 +770,7 @@ int main(int argc, char ** argv) {
}
// check for reverse prompt using special tokens
llama_token last_token = llama_sampling_last(ctx_sampling);
llama_token last_token = gpt_sampler_last(smpl);
for (std::vector<llama_token> ids : antiprompt_ids) {
if (ids.size() == 1 && last_token == ids[0]) {
if (params.interactive) {
@ -871,7 +787,7 @@ int main(int argc, char ** argv) {
}
// deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
LOG("found an EOG token\n");
if (params.interactive) {
@ -892,7 +808,7 @@ int main(int argc, char ** argv) {
// if current token is not EOG, we add it to current assistant message
if (params.conversation) {
auto id = llama_sampling_last(ctx_sampling);
const auto id = gpt_sampler_last(smpl);
assistant_ss << llama_token_to_piece(ctx, id, false);
}
@ -988,7 +904,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(ctx_sampling);
gpt_sampler_reset(smpl);
}
is_interacting = false;
}
@ -1013,14 +929,15 @@ int main(int argc, char ** argv) {
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
llama_print_timings(ctx);
LOG_TEE("\n");
gpt_perf_print(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
if (ctx_guidance) { llama_free(ctx_guidance); }
gpt_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
llama_sampling_free(ctx_sampling);
llama_backend_free();
ggml_threadpool_free(threadpool);