diff --git a/llama.cpp b/llama.cpp index cdfb1bbb6..6e0f96bf2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l template void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { + T* element = std::max_element( + array, array + size, + [&logit_accessor](T& lhs, T& rhs) { + return logit_accessor(lhs) < logit_accessor(rhs); + } + ); + + float max_l = logit_accessor(*element); float sum = 0.f; for (int i = 0; i < size; ++i) { float& logit = logit_accessor(array[i]); - float p = expf(logit); + float p = expf(logit - max_l); sum += p; logit = p; }