Make Classifier-Free Guidance a sampling function

Author: Bach Le
Date:   2023-07-07 23:45:37 +08:00
Parent: 114d4c5389
Commit: 422a7ffdaf

3 changed files with 71 additions and 29 deletions

examples/main/main.cpp

@@ -54,20 +54,6 @@ void sigint_handler(int signo) {
 }
 #endif
 
-void inplace_log_softmax(float* logits, int n_vocab) {
-    float sum = 0.f;
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = expf(logits[i]);
-        logits[i] = p;
-        sum += p;
-    }
-
-    for (int i = 0; i < n_vocab; ++i) {
-        float p = logits[i];
-        logits[i] = logf(p / sum);
-    }
-}
-
 int main(int argc, char ** argv) {
     gpt_params params;
@@ -554,21 +540,6 @@ int main(int argc, char ** argv) {
                 logits[it->first] += it->second;
             }
 
-            if (guidance_ctx) {
-                inplace_log_softmax(logits, n_vocab);
-
-                auto* guidance_logits = llama_get_logits(guidance_ctx);
-                inplace_log_softmax(guidance_logits, n_vocab);
-
-                for (int i = 0; i < n_vocab; ++i) {
-                    guidance_logits[i] = params.cfg_scale * (logits[i] - guidance_logits[i]) + guidance_logits[i];
-                }
-
-                inplace_log_softmax(guidance_logits, n_vocab);
-
-                for (int i = 0; i < n_vocab; ++i) {
-                    logits[i] = guidance_logits[i] * params.cfg_smooth_factor + logits[i] * (1 - params.cfg_smooth_factor);
-                }
-            }
-
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@@ -577,6 +548,10 @@ int main(int argc, char ** argv) {
 
             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
+            if (guidance_ctx) {
+                llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, params.cfg_scale, params.cfg_smooth_factor);
+            }
+
             // Apply penalties
             float nl_logit = logits[llama_token_nl()];
             auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
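
Together, these main.cpp hunks move the inline guidance arithmetic behind the new sampling function without changing it. In log-probability space the computation is, in notation of my own (not from the commit):

    g_i    = \log p^{\mathrm{cond}}_i \cdot s + \log p^{\mathrm{uncond}}_i \cdot (1 - s)
           = \log p^{\mathrm{uncond}}_i + s \, (\log p^{\mathrm{cond}}_i - \log p^{\mathrm{uncond}}_i)
    \ell_i = m \cdot \operatorname{logsoftmax}(g)_i + (1 - m) \cdot \log p^{\mathrm{cond}}_i

where p^{\mathrm{cond}} is the distribution from the main context ctx, p^{\mathrm{uncond}} the one from guidance_ctx, s is cfg_scale, and m is cfg_smooth_factor. As a sanity check, with s = 1 the guidance term cancels and \ell_i reduces to \log p^{\mathrm{cond}}_i, matching the header comment below that a scale of 1.0f means no guidance.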

llama.cpp

@@ -2141,6 +2141,61 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }
 
+template<typename T, typename LogitAccessor>
+void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
+    float sum = 0.f;
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        float p = expf(logit);
+        sum += p;
+        logit = p;
+    }
+
+    for (int i = 0; i < size; ++i) {
+        float& logit = logit_accessor(array[i]);
+        logit = logf(logit / sum);
+    }
+}
+
+void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor) {
+    assert(ctx);
+
+    auto n_vocab = llama_n_vocab(ctx);
+
+    assert(n_vocab == (int)candidates->size);
+    assert(!candidates->sorted);
+
+    auto logit_from_token_data = [](llama_token_data& data) -> float& {
+        return data.logit;
+    };
+
+    auto logit_from_float = [](float& item) -> float& {
+        return item;
+    };
+
+    llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);
+
+    auto* guidance_logits = llama_get_logits(guidance_ctx);
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float guidance_logit = guidance_logits[i];
+        float base_logit = candidates->data[i].logit;
+        guidance_logits[i] = scale * (base_logit - guidance_logit) + guidance_logit;
+    }
+
+    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);
+
+    for (int i = 0; i < n_vocab; ++i) {
+        float base_logit = candidates->data[i].logit;
+        float guidance_logit = guidance_logits[i];
+        candidates->data[i].logit = smooth_factor * guidance_logit + (1.f - smooth_factor) * base_logit;
+    }
+}
+
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
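
One property of the new helper worth knowing when reusing it: llama_log_softmax exponentiates raw logits, so expf overflows to infinity for logits much above ~88. A numerically stable standalone variant (a sketch of mine, not part of this commit) subtracts the maximum logit before exponentiating:

#include <cmath>
#include <cstdio>

// Sketch, not from the commit: log-softmax via logsumexp with max-subtraction,
// so expf never sees a positive argument and cannot overflow.
static void log_softmax_stable(float * logits, int n) {
    float max_logit = logits[0];
    for (int i = 1; i < n; ++i) {
        if (logits[i] > max_logit) {
            max_logit = logits[i];
        }
    }
    float sum = 0.f;
    for (int i = 0; i < n; ++i) {
        sum += expf(logits[i] - max_logit);
    }
    const float log_sum = logf(sum) + max_logit;  // logsumexp(logits)
    for (int i = 0; i < n; ++i) {
        logits[i] -= log_sum;                     // log p_i = logit_i - logsumexp
    }
}

int main() {
    float logits[4] = { 1000.f, 999.f, 998.f, 0.f };  // naive expf would overflow here
    log_softmax_stable(logits, 4);
    for (int i = 0; i < 4; ++i) {
        printf("%f\n", logits[i]);
    }
    return 0;
}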

llama.h

@@ -307,6 +307,18 @@ extern "C" {
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 
+    /// @details Apply classifier-free guidance to the logits as described in the paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+    /// @param candidates A vector of `llama_token_data` containing the candidate tokens; the logits must be directly extracted from the original generation context without being sorted.
+    /// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+    /// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
+    LLAMA_API void llama_sample_context_free_guidance(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        struct llama_context * guidance_ctx,
+        float scale,
+        float smooth_factor);
+
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
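
A minimal call-site sketch (hypothetical, mirroring the main.cpp hunk above; it assumes ctx and guidance_ctx have already been created from the same model and evaluated as the header comment requires):

#include <vector>
#include "llama.h"

// Hypothetical helper, not from the commit: samples one token with CFG applied.
llama_token sample_with_guidance(llama_context * ctx, llama_context * guidance_ctx,
                                 float cfg_scale, float cfg_smooth_factor) {
    const int n_vocab = llama_n_vocab(ctx);
    float * logits = llama_get_logits(ctx);

    // Build the candidate array from the raw, unsorted logits of the main
    // context, exactly as the @param candidates comment requires.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.push_back({ token_id, logits[token_id], 0.0f });
    }
    // `sorted` must stay false: the function asserts the array is unsorted.
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Blend the two distributions; candidates_p now holds the guided logits.
    llama_sample_context_free_guidance(ctx, &candidates_p, guidance_ctx, cfg_scale, cfg_smooth_factor);

    // Any existing sampler can run afterwards; greedy keeps the sketch short.
    return llama_sample_token_greedy(ctx, &candidates_p);
}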