Shift all values by the max value before applying logsoftmax

This commit is contained in:
Bach Le 2023-07-08 00:10:26 +08:00
parent 8e66e59cdd
commit 325fc88141

View file

@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
template<typename T, typename LogitAccessor> template<typename T, typename LogitAccessor>
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) { void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
T* element = std::max_element(
array, array + size,
[&logit_accessor](T& lhs, T& rhs) {
return logit_accessor(lhs) < logit_accessor(rhs);
}
);
float max_l = logit_accessor(*element);
float sum = 0.f; float sum = 0.f;
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
float& logit = logit_accessor(array[i]); float& logit = logit_accessor(array[i]);
float p = expf(logit); float p = expf(logit - max_l);
sum += p; sum += p;
logit = p; logit = p;
} }