Shift all values by the max value before applying logsoftmax
This commit is contained in:
parent
8e66e59cdd
commit
325fc88141
1 changed files with 9 additions and 1 deletions
10
llama.cpp
10
llama.cpp
|
@ -2143,10 +2143,18 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
|
||||||
|
|
||||||
template<typename T, typename LogitAccessor>
|
template<typename T, typename LogitAccessor>
|
||||||
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
|
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
|
||||||
|
T* element = std::max_element(
|
||||||
|
array, array + size,
|
||||||
|
[&logit_accessor](T& lhs, T& rhs) {
|
||||||
|
return logit_accessor(lhs) < logit_accessor(rhs);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
float max_l = logit_accessor(*element);
|
||||||
float sum = 0.f;
|
float sum = 0.f;
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
float& logit = logit_accessor(array[i]);
|
float& logit = logit_accessor(array[i]);
|
||||||
float p = expf(logit);
|
float p = expf(logit - max_l);
|
||||||
sum += p;
|
sum += p;
|
||||||
logit = p;
|
logit = p;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue