Make Classifier-Free Guidance a sampling function
parent 114d4c5389
commit 422a7ffdaf
3 changed files with 71 additions and 29 deletions
@@ -54,20 +54,6 @@ void sigint_handler(int signo) {
 }
 #endif
 
-void inplace_log_softmax(float* logits, int n_vocab) {
-    float sum = 0.f;
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = expf(logits[i]);
-        logits[i] = p;
-        sum += p;
-    }
-
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = logits[i];
-        logits[i] = logf(p/ sum);
-    }
-}
-
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -554,21 +540,6 @@ int main(int argc, char ** argv) {
                 logits[it->first] += it->second;
             }
 
-            if (guidance_ctx) {
-                inplace_log_softmax(logits, n_vocab);
-                auto* guidance_logits = llama_get_logits(guidance_ctx);
-                inplace_log_softmax(guidance_logits, n_vocab);
-
-                for (int i = 0; i < n_vocab; ++i) {
-                    guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i];
-                }
-                inplace_log_softmax(guidance_logits, n_vocab);
-
-                for (int i = 0; i < n_vocab; ++i) {
-                    logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor);
-                }
-            }
-
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -577,6 +548,10 @@ int main(int argc, char ** argv) {
 
             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+            if (guidance_ctx) {
+                llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
+            }
+
             // Apply penalties
             float nl_logit = logits[llama_token_nl()];
             auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
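For reference, the inline block removed above and the new sampler added to llama.cpp below perform the same combination in log-probability space: with l = log_softmax(logits) from the main context, g = log_softmax(guidance logits) from the negative-prompt context, s = cfg_scale and a = cfg_smooth_factor, they compute g' = g + s * (l - g), renormalize it as g_hat = log_softmax(g'), and hand back a * g_hat + (1 - a) * l as the logits used for sampling.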
llama.cpp (55 lines changed)
@@ -2141,6 +2141,61 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }
 
+template<typename T, typename LogitAccessor>
+void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
+    float sum = 0.f;
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        float p = expf(logit);
+        sum += p;
+        logit = p;
+    }
+
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        logit = logf(logit / sum);
+    }
+}
+
+void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor) {
+    assert(ctx);
+    auto n_vocab = llama_n_vocab(ctx);
+    assert(n_vocab == (int)candidates->size);
+    assert(!candidates->sorted);
+
+    auto logit_from_token_data = [](llama_token_data& data) -> float& {
+        return data.logit;
+    };
+
+    auto logit_from_float = [](float& item) -> float& {
+        return item;
+    };
+
+    llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);
+
+    auto* guidance_logits = llama_get_logits(guidance_ctx);
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float guidance_logit = guidance_logits[i];
+        float base_logit = candidates->data[i].logit;
+        guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
+    }
+
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float base_logit = candidates->data[i].logit;
+        float guidance_logit = guidance_logits[i];
+
+        candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
+    }
+}
+
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
 
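To make the new sampler's arithmetic easy to check in isolation, here is a minimal standalone sketch that runs the same three steps (log-softmax both logit sets, shift the guidance logits by scale * (base - guidance), renormalize and blend with smooth_factor) on a toy vocabulary. The helper log_softmax, the main driver and the example numbers are illustrative only and are not part of llama.cpp.

#include <cmath>
#include <cstdio>
#include <vector>

// Two-pass log-softmax, the same computation as the llama_log_softmax
// template above, specialized to a plain float vector.
static void log_softmax(std::vector<float> & v) {
    float sum = 0.f;
    for (float & x : v) { x = std::exp(x); sum += x; }
    for (float & x : v) { x = std::log(x / sum); }
}

int main() {
    // Toy three-token vocabulary; the numbers are made up for illustration.
    std::vector<float> base     = { 1.0f, 0.5f, -0.2f };   // logits from the main context
    std::vector<float> guidance = { 0.8f, 0.9f, -0.1f };   // logits from the negative-prompt context
    const float scale         = 1.5f;    // plays the role of cfg_scale
    const float smooth_factor = 0.75f;   // plays the role of cfg_smooth_factor

    log_softmax(base);
    log_softmax(guidance);

    // Shift the guidance logits towards the base logits by the guidance strength.
    for (size_t i = 0; i < base.size(); ++i) {
        guidance[i] = scale * (base[i] - guidance[i]) + guidance[i];
    }
    log_softmax(guidance);

    // Blend the renormalized guidance logits with the original base logits.
    for (size_t i = 0; i < base.size(); ++i) {
        base[i] = smooth_factor * guidance[i] + (1.f - smooth_factor) * base[i];
        std::printf("token %zu: combined logit %.4f\n", i, base[i]);
    }
    return 0;
}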
llama.h (12 lines changed)
@@ -307,6 +307,18 @@ extern "C" {
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 
+    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
+    /// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    /// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
+    LLAMA_API void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor);
+
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
 
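The declaration above is meant to be called once per sampling step, before anything sorts the candidate array. A condensed, hypothetical helper sketching the intended call pattern (mirroring the change to the example program above) might look as follows; creating the two contexts from the same model, evaluating the prompt on ctx and the negative prompt plus shared tokens on guidance_ctx, and the remaining sampling steps are assumed to happen elsewhere.

#include <vector>
#include "llama.h"

// Hypothetical helper showing how the new API slots into a sampling loop.
// ctx and guidance_ctx are assumed to come from the same model and to have
// been evaluated already; cfg_scale and cfg_smooth_factor mirror the example's
// command-line parameters.
static void apply_cfg(llama_context * ctx, llama_context * guidance_ctx,
                      float cfg_scale, float cfg_smooth_factor) {
    float * logits    = llama_get_logits(ctx);
    const int n_vocab = llama_n_vocab(ctx);

    // Build the candidate array from the raw, unsorted logits of the main context.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Guidance must run before anything sorts the array: the function asserts
    // !candidates->sorted and expects logits straight from the generation context.
    if (guidance_ctx) {
        llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx,
                                           cfg_scale, cfg_smooth_factor);
    }

    // ... penalties and the final token pick (e.g. llama_sample_token) follow here.
}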