From 3f8e444d0d7fd06835d790732ce2b7c722be6d97 Mon Sep 17 00:00:00 2001
From: Bartosz Podkanowicz
Date: Thu, 9 Nov 2023 21:21:35 +0100
Subject: [PATCH] fix error in the formula - formula now is similar to formula in the paper.

---
NOTE(review): line breaks, indentation and the <...> tokens were lost in
extraction and are reconstructed here. The context header names
(<cmath>, <cstdio>, <string>) are presumed; confirm against the target file.

 examples/contrastive/contrastive.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/contrastive/contrastive.cpp b/examples/contrastive/contrastive.cpp
index 83610dcfe..24616b347 100644
--- a/examples/contrastive/contrastive.cpp
+++ b/examples/contrastive/contrastive.cpp
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 
+#include <algorithm>
 #include <cmath>
 #include <cstdio>
 #include <string>
@@ -144,6 +145,7 @@ int main(int argc, char ** argv) {
     int n_cur = batch.n_tokens;
     int n_decode = 0;
 
+    float log_alpha = std::log(alpha);
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_len) {
@@ -156,9 +158,10 @@ int main(int argc, char ** argv) {
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
 
+            auto largest_expert_logit = *std::max_element(logits_expert, logits_expert + n_vocab);
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                 float cd_logit = std::numeric_limits<float>::lowest();
-                if (logits_expert[token_id] > alpha) {
+                if (logits_expert[token_id] > log_alpha + largest_expert_logit) {
                     cd_logit = (1+beta)*logits_expert[token_id] - beta*logits_amateur[token_id];
                 }
                 candidates.emplace_back(llama_token_data{ token_id, cd_logit, 0.0f });