perplexity : add <cmath>

This commit is contained in:
Georgi Gerganov 2023-03-28 19:40:01 +03:00
parent 61733d3b49
commit 21e9ce7574
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1,6 +1,8 @@
#include "common.h" #include "common.h"
#include "llama.h" #include "llama.h"
#include <cmath>
std::vector<float> softmax(const std::vector<float>& logits) { std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size()); std::vector<float> probs(logits.size());
float max_logit = logits[0]; float max_logit = logits[0];
@ -9,7 +11,7 @@ std::vector<float> softmax(const std::vector<float>& logits) {
for (size_t i = 0; i < logits.size(); i++) { for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability // Subtract the maximum logit value from the current logit value for numerical stability
const float logit = logits[i] - max_logit; const float logit = logits[i] - max_logit;
const float exp_logit = std::expf(logit); const float exp_logit = expf(logit);
sum_exp += exp_logit; sum_exp += exp_logit;
probs[i] = exp_logit; probs[i] = exp_logit;
} }