Fix error in the formula: the formula now matches the one in the paper.

Bartosz Podkanowicz 2023-11-09 21:21:35 +01:00
parent 1cf0b09273
commit 3f8e444d0d


@@ -1,6 +1,7 @@
#include "common.h"
#include "llama.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <string>
@@ -144,6 +145,7 @@ int main(int argc, char ** argv) {
int n_cur = batch.n_tokens;
int n_decode = 0;
float log_alpha = std::log(alpha);
const auto t_main_start = ggml_time_us();
while (n_cur <= n_len) {
@@ -156,9 +158,10 @@ int main(int argc, char ** argv) {
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
auto largest_expert_logit = *std::max_element(logits_expert, logits_expert + n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
float cd_logit = std::numeric_limits<float>::lowest();
if (logits_expert[token_id] > alpha) {
if (logits_expert[token_id] > log_alpha + largest_expert_logit) {
cd_logit = (1+beta)*logits_expert[token_id] - beta*logits_amateur[token_id];
}
candidates.emplace_back(llama_token_data{ token_id, cd_logit, 0.0f });
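
For reference, a sketch of the relation the new condition appears to implement (my reading of the diff, not part of the commit): the plausibility constraint in the contrastive decoding paper keeps a token only if its expert probability is at least an alpha fraction of the most probable token's probability. The old check compared a raw logit against alpha directly, mixing a log-scale quantity with a probability. Taking logs of the constraint, and noting that the softmax normalizer is identical on both sides and cancels, gives a test on raw logits, which is what the new `log_alpha + largest_expert_logit` comparison does; beta then weights the expert/amateur contrast for the surviving tokens.

% Plausibility constraint: keep token x only if its expert probability
% is at least an alpha fraction of the most probable token's probability.
p_{\mathrm{exp}}(x) \;\ge\; \alpha \cdot \max_{w} p_{\mathrm{exp}}(w)

% Taking logs; the softmax normalizer cancels on both sides,
% so the same test can be applied to raw logits:
\mathrm{logit}_{\mathrm{exp}}(x) \;\ge\; \log\alpha \;+\; \max_{w} \mathrm{logit}_{\mathrm{exp}}(w)

% Score assigned to tokens that pass the constraint
% (all other tokens keep the lowest-float sentinel):
\mathrm{cd\_logit}(x) \;=\; (1+\beta)\,\mathrm{logit}_{\mathrm{exp}}(x) \;-\; \beta\,\mathrm{logit}_{\mathrm{ama}}(x)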