minor : comments + rename
ggml-ci
parent 1c626e2fe1
commit 373d782d42

4 changed files with 13 additions and 7 deletions
@@ -69,7 +69,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_guidance,
+        struct llama_context * ctx_cfg,
         const int idx) {
     const int n_ctx = llama_n_ctx(ctx_main);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -110,8 +110,8 @@ llama_token llama_sampling_sample(

     llama_token_data_array cur_p = { cur.data(), cur.size(), false };

-    if (ctx_guidance) {
-        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
+    if (ctx_cfg) {
+        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
     }

     // apply penalties
@@ -80,7 +80,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 // - ctx_sampling: sampling-specific context
 //
 // optional:
-// - ctx_guidance: context to use for guidance
+// - ctx_cfg: context to use for classifier-free guidance
 // - idx: sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
@@ -90,7 +90,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        struct llama_context * ctx_guidance,
+        struct llama_context * ctx_cfg,
         int idx = 0);

 void llama_sampling_accept(
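For orientation, a minimal sketch of what a call site looks like after this rename. Only `llama_sampling_sample` and `llama_get_logits_ith` come from the diff itself; the wrapper function and variable names are illustrative assumptions, and the contexts are assumed to be created elsewhere.

```cpp
// Illustrative sketch only; sample_next_token and its arguments are assumptions,
// not part of this commit. A null ctx_cfg means classifier-free guidance is disabled.
static llama_token sample_next_token(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context          * ctx_main,
        struct llama_context          * ctx_cfg) {
    // idx = 0: sample from the logits returned by llama_get_logits_ith(ctx_main, 0)
    return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, /*idx =*/ 0);
}
```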
@@ -119,8 +119,8 @@ int main(int argc, char ** argv) {

     const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;

-    // GG: are we sure that the should be a trailing whitespace at the end of this string?
-    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past);
+    // GG: are we sure that there should be a trailing whitespace at the end of this string?
+    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
     eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
     eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
     eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
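Condensed, the hunk above evaluates a Vicuna-style chat prompt with the image embeddings placed between "USER: " (now including the trailing space) and the user's text. A sketch of that sequence, using only the helpers visible in the hunk; the `system_prompt` variable is an assumption standing in for the long string literal, and error handling is omitted.

```cpp
// Prompt layout: [system prompt ending in "\nUSER: "] [image embeddings] [user text] "\nASSISTANT:"
int n_past = 0;
eval_string(ctx_llama, system_prompt, params.n_batch, &n_past);              // text up to and including "USER: "
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);  // image patch embeddings
eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);      // the user's question
eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);             // cue the model to answer
```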
@@ -612,8 +612,14 @@ int main(int argc, char ** argv) {
         LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
         while ((int) embd_inp.size() > n_consumed) {
             embd.push_back(embd_inp[n_consumed]);
+
+            // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
+            //     Most likely will remove this in the future to avoid exposing "prev"
+            //     Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
+            //     penalty will be applied only based on the tokens generated by the model.
+
             ctx_sampling->prev.erase(ctx_sampling->prev.begin());
             ctx_sampling->prev.push_back(embd_inp[n_consumed]);

             ++n_consumed;
             if ((int) embd.size() >= params.n_batch) {
                 break;
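To unpack the new comment: a repetition penalty can only see the tokens stored in `prev`, so pushing prompt tokens into the sampling context makes the penalty also discourage repeating the prompt, not just the model's own output. Below is a minimal, self-contained sketch of such a penalty; this is generic CTRL-style logic under stated assumptions, not the actual llama.cpp implementation.

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

using llama_token = int32_t; // stand-in for the real typedef

// Penalize every token id that appears in the recent history `prev`.
// If prompt tokens are pushed into `prev`, they get penalized as well.
static void apply_repetition_penalty(std::vector<float> & logits,
                                     const std::vector<llama_token> & prev,
                                     float penalty /* e.g. 1.1 */) {
    std::unordered_set<llama_token> seen(prev.begin(), prev.end());
    for (llama_token t : seen) {
        if (t < 0 || (size_t) t >= logits.size()) {
            continue;
        }
        // CTRL-style: shrink positive logits, push negative logits further down
        logits[t] = logits[t] > 0.0f ? logits[t] / penalty : logits[t] * penalty;
    }
}
```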