From 325fc8814155a325bea879c92104e34e0145d7ea Mon Sep 17 00:00:00 2001 From: Bach Le Date: Sat, 8 Jul 2023 00:10:26 +0800 Subject: [PATCH] Shift all values by the max value before applying logsoftmax --- llama.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index cdfb1bbb6..6e0f96bf2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l template void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { + T* element = std::max_element( + array, array + size, + [&logit_accessor](T& lhs, T& rhs) { + return logit_accessor(lhs) < logit_accessor(rhs); + } + ); + + float max_l = logit_accessor(*element); float sum = 0.f; for (int i = 0; i < size; ++i) { float& logit = logit_accessor(array[i]); - float p = expf(logit); + float p = expf(logit - max_l); sum += p; logit = p; }