Fix styling based on review

This commit is contained in:
Bach Le 2023-07-10 23:50:17 +08:00
parent 325fc88141
commit abf164d71e
4 changed files with 49 additions and 62 deletions

View file

@ -556,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res; return res;
} }
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) { struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params(); auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx; lparams.n_ctx = params.n_ctx;
@ -576,7 +576,7 @@ struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_p
} }
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) { std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_get_context_params_from_gpt_params(params); auto lparams = llama_context_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) { if (model == NULL) {

View file

@ -105,7 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
// //
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params); std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// //
// Console utils // Console utils

View file

@ -109,14 +109,14 @@ int main(int argc, char ** argv) {
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;
llama_context * guidance_ctx = NULL; llama_context * ctx_guidance = NULL;
g_ctx = &ctx; g_ctx = &ctx;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) { if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params); struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
guidance_ctx = llama_new_context_with_model(model, lparams); ctx_guidance = llama_new_context_with_model(model, lparams);
} }
if (model == NULL) { if (model == NULL) {
@ -202,9 +202,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> guidance_inp; std::vector<llama_token> guidance_inp;
int guidance_offset = 0; int guidance_offset = 0;
int original_prompt_len = 0; int original_prompt_len = 0;
if (guidance_ctx) { if (ctx_guidance) {
params.cfg_negative_prompt.insert(0, 1, ' '); params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true); guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true); std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
original_prompt_len = original_inp.size(); original_prompt_len = original_inp.size();
@ -278,7 +278,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
} }
if (guidance_ctx) { if (ctx_guidance) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
@ -363,13 +363,13 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict; int n_remain = params.n_predict;
int n_consumed = 0; int n_consumed = 0;
int n_session_consumed = 0; int n_session_consumed = 0;
int guidance_n_past = 0; int n_past_guidance = 0;
// the first thing we will do is to output the prompt, so set color accordingly // the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT); console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::vector<llama_token> guidance_embd; std::vector<llama_token> embd_guidance;
// do one empty run to warm up the model // do one empty run to warm up the model
{ {
@ -403,7 +403,7 @@ int main(int argc, char ** argv) {
// always keep the first token - BOS // always keep the first token - BOS
n_past = std::max(1, params.n_keep); n_past = std::max(1, params.n_keep);
guidance_n_past = std::max(1, params.n_keep + guidance_offset); n_past_guidance = std::max(1, params.n_keep + guidance_offset);
// insert n_left/2 tokens at the start of embd from last_n_tokens // insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@ -445,29 +445,29 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches // evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always // embd is typically prepared beforehand to fit within a batch, but not always
if (guidance_ctx) { if (ctx_guidance) {
int input_size = 0; int input_size = 0;
llama_token* input_buf = NULL; llama_token* input_buf = NULL;
if (guidance_n_past < (int) guidance_inp.size()) { if (n_past_guidance < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications: // Guidance context should have the same data with these modifications:
// //
// * Replace the initial prompt // * Replace the initial prompt
// * Shift everything by guidance_offset // * Shift everything by guidance_offset
guidance_embd = guidance_inp; embd_guidance = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) { if (embd.begin() + original_prompt_len < embd.end()) {
guidance_embd.insert( embd_guidance.insert(
guidance_embd.end(), embd_guidance.end(),
embd.begin() + original_prompt_len, embd.begin() + original_prompt_len,
embd.end() embd.end()
); );
} }
input_buf = guidance_embd.data(); input_buf = embd_guidance.data();
input_size = guidance_embd.size(); input_size = embd_guidance.size();
//fprintf(stderr, "\n---------------------\n"); //fprintf(stderr, "\n---------------------\n");
//for (int i = 0; i < (int) guidance_embd.size(); i++) { //for (int i = 0; i < (int) embd_guidance.size(); i++) {
//fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i])); //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
//} //}
//fprintf(stderr, "\n---------------------\n"); //fprintf(stderr, "\n---------------------\n");
} else { } else {
@ -477,12 +477,12 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) { for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch); int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) { if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }
guidance_n_past += n_eval; n_past_guidance += n_eval;
} }
} }
@ -505,7 +505,7 @@ int main(int argc, char ** argv) {
} }
embd.clear(); embd.clear();
guidance_embd.clear(); embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token // out of user input, sample next token
@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (guidance_ctx) { if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor); llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
} }
// Apply penalties // Apply penalties
@ -747,7 +747,7 @@ int main(int argc, char ** argv) {
} }
llama_print_timings(ctx); llama_print_timings(ctx);
if (guidance_ctx) { llama_free(guidance_ctx); } if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);

View file

@ -2141,27 +2141,17 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
} }
} }
template<typename T, typename LogitAccessor> static void llama_log_softmax(float * array, size_t size) {
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { float max_l = *std::max_element(array, array + size);
T* element = std::max_element(
array, array + size,
[&logit_accessor](T& lhs, T& rhs) {
return logit_accessor(lhs) < logit_accessor(rhs);
}
);
float max_l = logit_accessor(*element);
float sum = 0.f; float sum = 0.f;
for (int i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]); float p = expf(array[i] - max_l);
float p = expf(logit - max_l);
sum += p; sum += p;
logit = p; array[i] = p;
} }
for (int i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]); array[i] = logf(array[i] / sum);
logit = logf(logit / sum);
} }
} }
@ -2178,32 +2168,29 @@ void llama_sample_classifier_free_guidance(
assert(n_vocab == (int)candidates->size); assert(n_vocab == (int)candidates->size);
assert(!candidates->sorted); assert(!candidates->sorted);
auto logit_from_token_data = [](llama_token_data& data) -> float& { std::vector<float> logits_base;
return data.logit; logits_base.reserve(candidates->size);
}; for (size_t i = 0; i < candidates->size; ++i) {
logits_base.push_back(candidates->data[i].logit);
}
llama_log_softmax(logits_base.data(), candidates->size);
auto logit_from_float = [](float& item) -> float& { float* logits_guidance = llama_get_logits(guidance_ctx);
return item; llama_log_softmax(logits_guidance, n_vocab);
};
llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);
auto* guidance_logits = llama_get_logits(guidance_ctx);
llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
float guidance_logit = guidance_logits[i]; float logit_guidance = logits_guidance[i];
float base_logit = candidates->data[i].logit; float logit_base = logits_base[i];
guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit; logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
} }
llama_log_softmax(guidance_logits, n_vocab, logit_from_float); llama_log_softmax(logits_guidance, n_vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
float base_logit = candidates->data[i].logit; float logit_base = logits_base[i];
float guidance_logit = guidance_logits[i]; float logit_guidance = logits_guidance[i];
candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit; candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
} }
if (ctx) { if (ctx) {