Fix styling based on review

This commit is contained in:
Bach Le 2023-07-10 23:50:17 +08:00
parent 325fc88141
commit abf164d71e
4 changed files with 49 additions and 62 deletions

View file

@ -556,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) {
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@ -576,7 +576,7 @@ struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_p
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_get_context_params_from_gpt_params(params);
auto lparams = llama_context_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {

View file

@ -105,7 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
// Console utils

View file

@ -109,14 +109,14 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * guidance_ctx = NULL;
llama_context * ctx_guidance = NULL;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
guidance_ctx = llama_new_context_with_model(model, lparams);
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
if (model == NULL) {
@ -202,9 +202,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> guidance_inp;
int guidance_offset = 0;
int original_prompt_len = 0;
if (guidance_ctx) {
if (ctx_guidance) {
params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
original_prompt_len = original_inp.size();
@ -278,7 +278,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
}
if (guidance_ctx) {
if (ctx_guidance) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
@ -363,13 +363,13 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
int guidance_n_past = 0;
int n_past_guidance = 0;
// the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
std::vector<llama_token> guidance_embd;
std::vector<llama_token> embd_guidance;
// do one empty run to warm up the model
{
@ -403,7 +403,7 @@ int main(int argc, char ** argv) {
// always keep the first token - BOS
n_past = std::max(1, params.n_keep);
guidance_n_past = std::max(1, params.n_keep + guidance_offset);
n_past_guidance = std::max(1, params.n_keep + guidance_offset);
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@ -445,29 +445,29 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
if (guidance_ctx) {
if (ctx_guidance) {
int input_size = 0;
llama_token* input_buf = NULL;
if (guidance_n_past < (int) guidance_inp.size()) {
if (n_past_guidance < (int) guidance_inp.size()) {
// Guidance context should have the same data with these modifications:
//
// * Replace the initial prompt
// * Shift everything by guidance_offset
guidance_embd = guidance_inp;
embd_guidance = guidance_inp;
if (embd.begin() + original_prompt_len < embd.end()) {
guidance_embd.insert(
guidance_embd.end(),
embd_guidance.insert(
embd_guidance.end(),
embd.begin() + original_prompt_len,
embd.end()
);
}
input_buf = guidance_embd.data();
input_size = guidance_embd.size();
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
//fprintf(stderr, "\n---------------------\n");
//for (int i = 0; i < (int) guidance_embd.size(); i++) {
//fprintf(stderr, "%s", llama_token_to_str(ctx, guidance_embd[i]));
//for (int i = 0; i < (int) embd_guidance.size(); i++) {
//fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
//}
//fprintf(stderr, "\n---------------------\n");
} else {
@ -477,12 +477,12 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(guidance_ctx, input_buf + i, n_eval, guidance_n_past, params.n_threads)) {
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
guidance_n_past += n_eval;
n_past_guidance += n_eval;
}
}
@ -505,7 +505,7 @@ int main(int argc, char ** argv) {
}
embd.clear();
guidance_embd.clear();
embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (guidance_ctx) {
llama_sample_classifier_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
}
// Apply penalties
@ -747,7 +747,7 @@ int main(int argc, char ** argv) {
}
llama_print_timings(ctx);
if (guidance_ctx) { llama_free(guidance_ctx); }
if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);

View file

@ -2141,27 +2141,17 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
template<typename T, typename LogitAccessor>
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
T* element = std::max_element(
array, array + size,
[&logit_accessor](T& lhs, T& rhs) {
return logit_accessor(lhs) < logit_accessor(rhs);
}
);
float max_l = logit_accessor(*element);
static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size);
float sum = 0.f;
for (int i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]);
float p = expf(logit - max_l);
for (size_t i = 0; i < size; ++i) {
float p = expf(array[i] - max_l);
sum += p;
logit = p;
array[i] = p;
}
for (int i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]);
logit = logf(logit / sum);
for (size_t i = 0; i < size; ++i) {
array[i] = logf(array[i] / sum);
}
}
@ -2178,32 +2168,29 @@ void llama_sample_classifier_free_guidance(
assert(n_vocab == (int)candidates->size);
assert(!candidates->sorted);
auto logit_from_token_data = [](llama_token_data& data) -> float& {
return data.logit;
};
std::vector<float> logits_base;
logits_base.reserve(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) {
logits_base.push_back(candidates->data[i].logit);
}
llama_log_softmax(logits_base.data(), candidates->size);
auto logit_from_float = [](float& item) -> float& {
return item;
};
llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);
auto* guidance_logits = llama_get_logits(guidance_ctx);
llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
float* logits_guidance = llama_get_logits(guidance_ctx);
llama_log_softmax(logits_guidance, n_vocab);
for (int i = 0; i < n_vocab; ++i) {
float guidance_logit = guidance_logits[i];
float base_logit = candidates->data[i].logit;
guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
float logit_guidance = logits_guidance[i];
float logit_base = logits_base[i];
logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
}
llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
llama_log_softmax(logits_guidance, n_vocab);
for (int i = 0; i < n_vocab; ++i) {
float base_logit = candidates->data[i].logit;
float guidance_logit = guidance_logits[i];
float logit_base = logits_base[i];
float logit_guidance = logits_guidance[i];
candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
}
if (ctx) {