Check for llama_get_logits_ith() errors
Embeddings models like BERT don't have logits. This caused the llamafile software to crash for users who tried to inference mxbai-embed-large-v1. This change potentially helps prevent the server from crashing. Since it is possible for this function to fail having callers check the result is a good idea from a defensive coding standpoint. The older exception code has also been refactored, since it's no longer needed.
This commit is contained in:
parent
201cc11afa
commit
99291c04f7
8 changed files with 92 additions and 65 deletions
|
@ -195,6 +195,9 @@ static llama_token llama_sampling_sample_impl(
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
if (temp < 0.0) {
|
if (temp < 0.0) {
|
||||||
// greedy sampling, with probs
|
// greedy sampling, with probs
|
||||||
|
@ -284,6 +287,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
||||||
GGML_ASSERT(original_logits != NULL);
|
GGML_ASSERT(original_logits != NULL);
|
||||||
|
@ -298,6 +304,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
if (ctx_cfg) {
|
if (ctx_cfg) {
|
||||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||||
|
if (!logits_guidance) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||||
// sum up all token embeddings
|
// sum up all token embeddings
|
||||||
for (int32_t k = n_inst; k < n_toks; k++) {
|
for (int32_t k = n_inst; k < n_toks; k++) {
|
||||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||||
|
if (!emb) {
|
||||||
|
throw std::runtime_error("llama_get_embeddings_ith failed");
|
||||||
|
}
|
||||||
for (uint64_t j = 0; j < n_embd; j++) {
|
for (uint64_t j = 0; j < n_embd; j++) {
|
||||||
emb_unorm[j] += emb[j];
|
emb_unorm[j] += emb[j];
|
||||||
}
|
}
|
||||||
|
@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
|
||||||
|
|
||||||
llama_decode(ctx, bat);
|
llama_decode(ctx, bat);
|
||||||
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
||||||
auto n_candidates = (int32_t)candidates.size();
|
auto n_candidates = (int32_t)candidates.size();
|
||||||
|
|
|
@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
|
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
|
|
||||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||||
|
if (!all_logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||||
if (!params.logits_file.empty()) {
|
if (!params.logits_file.empty()) {
|
||||||
|
|
|
@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
127
llama.cpp
127
llama.cpp
|
@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
|
||||||
return ctx->logits;
|
return ctx->logits;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float * llama_get_logits_ith_fail(int i, std::string reason) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
#endif
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||||
int32_t j = -1;
|
int32_t j = -1;
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
if (ctx->logits == nullptr) {
|
||||||
try {
|
// this can happen for embeddings models like bert
|
||||||
if (ctx->logits == nullptr) {
|
return llama_get_logits_ith_fail(i, "no logits");
|
||||||
throw std::runtime_error("no logits");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i < 0) {
|
|
||||||
j = ctx->n_outputs + i;
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
|
||||||
}
|
|
||||||
} else if ((size_t) i >= ctx->output_ids.size()) {
|
|
||||||
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
|
|
||||||
} else {
|
|
||||||
j = ctx->output_ids[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
||||||
}
|
|
||||||
if (j >= ctx->n_outputs) {
|
|
||||||
// This should not happen
|
|
||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx->logits + j*ctx->model.hparams.n_vocab;
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
|
||||||
#ifndef NDEBUG
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
#endif
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
j = ctx->n_outputs + i;
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("negative index out of range [0, %d)", ctx->n_outputs));
|
||||||
|
}
|
||||||
|
} else if ((size_t) i >= ctx->output_ids.size()) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
|
||||||
|
} else {
|
||||||
|
j = ctx->output_ids[i];
|
||||||
|
}
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
|
||||||
|
}
|
||||||
|
if (j >= ctx->n_outputs) {
|
||||||
|
// This should not happen
|
||||||
|
return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
||||||
|
}
|
||||||
|
return ctx->logits + j*ctx->model.hparams.n_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
|
@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
return ctx->embd;
|
return ctx->embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float * llama_get_embeddings_ith_fail(int i, std::string reason) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
#endif
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||||
int32_t j = -1;
|
int32_t j = -1;
|
||||||
|
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
if (ctx->embd == nullptr) {
|
||||||
try {
|
return llama_get_embeddings_ith_fail(i, "no embeddings");
|
||||||
if (ctx->embd == nullptr) {
|
|
||||||
throw std::runtime_error("no embeddings");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i < 0) {
|
|
||||||
j = ctx->n_outputs + i;
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
|
||||||
}
|
|
||||||
} else if ((size_t) i >= ctx->output_ids.size()) {
|
|
||||||
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
|
|
||||||
} else {
|
|
||||||
j = ctx->output_ids[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
||||||
}
|
|
||||||
if (j >= ctx->n_outputs) {
|
|
||||||
// This should not happen
|
|
||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx->embd + j*ctx->model.hparams.n_embd;
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
|
||||||
#ifndef NDEBUG
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
#endif
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
j = ctx->n_outputs + i;
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("negative index out of range [0, %d)", ctx->n_outputs));
|
||||||
|
}
|
||||||
|
} else if ((size_t) i >= ctx->output_ids.size()) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("out of range [0, %lu)", ctx->output_ids.size()));
|
||||||
|
} else {
|
||||||
|
j = ctx->output_ids[i];
|
||||||
|
}
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("batch.logits[%d] != true", i));
|
||||||
|
}
|
||||||
|
if (j >= ctx->n_outputs) {
|
||||||
|
// This should not happen
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("corrupt output buffer (j=%d, n_outputs=%d)",
|
||||||
|
j, ctx->n_outputs));
|
||||||
|
}
|
||||||
|
return ctx->embd + j*ctx->model.hparams.n_embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue