Fix error in the formula - the formula is now similar to the one in the paper.
This commit is contained in:
parent
1cf0b09273
commit
3f8e444d0d
1 changed file with 4 additions and 1 deletion
|
@@ -1,6 +1,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@@ -144,6 +145,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
int n_cur = batch.n_tokens;
|
int n_cur = batch.n_tokens;
|
||||||
int n_decode = 0;
|
int n_decode = 0;
|
||||||
|
float log_alpha = std::log(alpha);
|
||||||
|
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
while (n_cur <= n_len) {
|
while (n_cur <= n_len) {
|
||||||
|
@@ -156,9 +158,10 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
|
auto largest_expert_logit = *std::max_element(logits_expert, logits_expert + n_vocab);
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
float cd_logit = std::numeric_limits<float>::lowest();
|
float cd_logit = std::numeric_limits<float>::lowest();
|
||||||
if (logits_expert[token_id] > alpha) {
|
if (logits_expert[token_id] > log_alpha + largest_expert_logit) {
|
||||||
cd_logit = (1+beta)*logits_expert[token_id] - beta*logits_amateur[token_id];
|
cd_logit = (1+beta)*logits_expert[token_id] - beta*logits_amateur[token_id];
|
||||||
}
|
}
|
||||||
candidates.emplace_back(llama_token_data{ token_id, cd_logit, 0.0f });
|
candidates.emplace_back(llama_token_data{ token_id, cd_logit, 0.0f });
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue