From 373d782d42ca17030827a8dcab1adba43849cfbf Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 16 Oct 2023 18:17:31 +0300
Subject: [PATCH] minor : comments + rename

ggml-ci
---
 common/sampling.cpp      | 6 +++---
 common/sampling.h        | 4 ++--
 examples/llava/llava.cpp | 4 ++--
 examples/main/main.cpp   | 6 ++++++
 4 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 388085fdc..0b2466581 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -69,7 +69,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_guidance,
+        struct llama_context * ctx_cfg,
         const int idx) {
     const int n_ctx = llama_n_ctx(ctx_main);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -110,8 +110,8 @@ llama_token llama_sampling_sample(
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
-    if (ctx_guidance) {
-        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
+    if (ctx_cfg) {
+        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
     }
 
     // apply penalties
diff --git a/common/sampling.h b/common/sampling.h
index bb3c6a63c..50afcbc12 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -80,7 +80,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 // - ctx_sampling: sampling-specific context
 //
 // optional:
-// - ctx_guidance: context to use for guidance
+// - ctx_cfg: context to use for classifier-free guidance
 // - idx: sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
@@ -90,7 +90,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_guidance,
+        struct llama_context * ctx_cfg,
         int idx = 0);
 
 void llama_sampling_accept(
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 1ac730cc7..de64ee713 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -119,8 +119,8 @@ int main(int argc, char ** argv) {
 
     const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
 
-    // GG: are we sure that the should be a trailing whitespace at the end of this string?
-    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past);
+    // GG: are we sure that there should be a trailing whitespace at the end of this string?
+    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
     eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
     eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
     eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 316c7bf05..97273aef5 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -612,8 +612,14 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
+
+                // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
+                //     Most likely will remove this in the future to avoid exposing "prev"
+                //     Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
+                //     penalty will be applied only based on the tokens generated by the model.
                 ctx_sampling->prev.erase(ctx_sampling->prev.begin());
                 ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
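
Note on the rename (not part of the patch): the change only renames the parameter, the call pattern stays the same. Below is a minimal usage sketch of the renamed API; only llama_sampling_sample() comes from common/sampling.h, while the wrapper function, variable names, and include path are illustrative assumptions.

    // usage sketch, assuming the contexts were created by the caller as in examples/main
    #include "sampling.h"

    static llama_token sample_next_token(
            struct llama_sampling_context * ctx_sampling,
            struct llama_context          * ctx_main,
            struct llama_context          * ctx_cfg) { // may be nullptr
        // ctx_cfg == nullptr  -> plain sampling, no guidance applied
        // ctx_cfg != nullptr  -> classifier-free guidance is mixed in with params.cfg_scale
        return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, 0);
    }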
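
Note on the main.cpp comment (not part of the patch): ctx_sampling->prev is a fixed-size token history that the repetition penalty later inspects. Pushing each consumed prompt token into it, as the two lines under the new comment do, means the penalty also "sees" the prompt; if the prompt tokens were no longer pushed, the penalty would be computed only over tokens generated by the model. A minimal sketch of that ring-buffer-style update (llama_token comes from llama.h; the helper name is hypothetical):

    #include <vector>
    #include "llama.h"

    // mirrors the two lines in main.cpp: drop the oldest entry, append the newest token
    static void push_prev(std::vector<llama_token> & prev, llama_token id) {
        prev.erase(prev.begin()); // prev is pre-sized, so the oldest slot is recycled
        prev.push_back(id);       // id may be a prompt token or a generated token
    }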