Check for llama_get_logits_ith() errors

Embeddings models like BERT don't produce logits. This caused the
llamafile software to crash for users who tried to run inference with
mxbai-embed-large-v1. This change should help prevent the server from
crashing: since it is possible for this function to fail, having
callers check the result is a good idea from a defensive coding
standpoint. The older exception-based code has also been refactored,
since it's no longer needed.
Justine Tunney 2024-05-21 14:13:38 -07:00
parent 201cc11afa
commit 99291c04f7
8 changed files with 92 additions and 65 deletions
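
For callers, the upshot is a simple defensive pattern: treat a null result from llama_get_logits_ith() as a recoverable failure instead of dereferencing it. A minimal caller-side sketch follows; the get_logits_checked() wrapper is hypothetical, not part of this diff, and assumes a ctx and idx coming from an existing decode loop:

#include <cstdio>
#include "llama.h"

// Hypothetical wrapper illustrating the defensive pattern this commit
// applies at each call site: check the pointer before using it.
static float * get_logits_checked(struct llama_context * ctx, int32_t idx) {
    float * logits = llama_get_logits_ith(ctx, idx);
    if (!logits) {
        // llama_get_logits_ith() now logs the reason and returns nullptr
        // (e.g. for embeddings-only models like BERT) instead of aborting.
        fprintf(stderr, "no logits for index %d\n", idx);
    }
    return logits; // caller must still handle nullptr
}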


@@ -195,6 +195,9 @@ static llama_token llama_sampling_sample_impl(
     llama_token id = 0;

     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }

     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -284,6 +287,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }

     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
@@ -298,6 +304,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        if (!logits_guidance) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }


@@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
     auto n_vocab = llama_n_vocab(model);
     auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
+    if (!logits) {
+        return 1;
+    }

     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);


@@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     // sum up all token embeddings
     for (int32_t k = n_inst; k < n_toks; k++) {
         float * emb = llama_get_embeddings_ith(ctx, k);
+        if (!emb) {
+            throw std::runtime_error("llama_get_embeddings_ith failed");
+        }
         for (uint64_t j = 0; j < n_embd; j++) {
             emb_unorm[j] += emb[j];
         }
@@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_decode(ctx, bat);

     auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }

     auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
     auto n_candidates = (int32_t)candidates.size();


@@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
     auto n_vocab = llama_n_vocab(model);
     auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }

     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);


@@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
     {
         auto n_vocab = llama_n_vocab(model);
         auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }

         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);


@@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     for (int seq = 0; seq < n_seq_batch; seq++) {
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+        if (!all_logits) {
+            return 1;
+        }

         llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
         if (!params.logits_file.empty()) {


@@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
     {
         auto n_vocab = llama_n_vocab(model);
         auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }

         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);

llama.cpp (127 changed lines)

@@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits;
 }

+static float * llama_get_logits_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;

     llama_synchronize(ctx);

-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
+    if (ctx->logits == nullptr) {
+        // this can happen for embeddings models like bert
+        return llama_get_logits_ith_fail(i, "no logits");
+    }

-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_logits_ith_fail(i, format("negative index out of range [0, %d)", ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }

-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
+    if (j < 0) {
+        return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+    }

-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
-    }
+    return ctx->logits + j*ctx->model.hparams.n_vocab;
 }

 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embd;
 }

+static float * llama_get_embeddings_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;

     llama_synchronize(ctx);

-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
+    if (ctx->embd == nullptr) {
+        return llama_get_embeddings_ith_fail(i, "no embeddings");
+    }

-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_embeddings_ith_fail(
+                i, format("negative index out of range [0, %d)", ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_embeddings_ith_fail(
+            i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }

-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
+    if (j < 0) {
+        return llama_get_embeddings_ith_fail(
+            i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_embeddings_ith_fail(
+            i, format("corrupt output buffer (j=%d, n_outputs=%d)",
+                j, ctx->n_outputs));
+    }

-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
-    }
+    return ctx->embd + j*ctx->model.hparams.n_embd;
 }

 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
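
As a usage note, embeddings models take the parallel path: llama_get_embeddings_ith() gets the same guard, so code serving a model like mxbai-embed-large-v1 should check that pointer too. A minimal sketch; the helper below is hypothetical, not part of this diff:

#include <stdexcept>
#include "llama.h"

// Hypothetical sketch: read one token's embedding defensively.
static float * get_embedding_checked(struct llama_context * ctx, int32_t i) {
    float * emb = llama_get_embeddings_ith(ctx, i);
    if (!emb) {
        // nullptr now signals failure (the reason is already logged by llama.cpp)
        throw std::runtime_error("llama_get_embeddings_ith failed");
    }
    return emb;
}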