Check for llama_get_logits_ith() errors
Embedding models like BERT don't have logits. This caused llamafile to crash for users who tried to run inference on mxbai-embed-large-v1. This change helps prevent the server from crashing: since it is possible for llama_get_logits_ith() to fail, having callers check the result is a good idea from a defensive coding standpoint. The older exception-based code has also been refactored, since it's no longer needed.
parent 201cc11afa
commit 99291c04f7
8 changed files with 92 additions and 65 deletions
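As a caller-side illustration of the defensive pattern this commit adopts (check the pointer returned by llama_get_logits_ith() before using it), here is a minimal sketch. It is not part of the commit: the helper name pick_greedy_token and its error handling are made up for the example, and it assumes a llama_context that has already decoded a batch with logits requested for its last token.

    // Sketch: greedily pick the next token, but tolerate a missing logits buffer
    // (e.g. an embeddings-only model such as BERT) instead of dereferencing nullptr.
    #include <cstdio>
    #include "llama.h"

    static bool pick_greedy_token(llama_context * ctx, const llama_batch & batch, llama_token & out) {
        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        if (!logits) {
            fprintf(stderr, "llama_get_logits_ith failed (no logits for this model/context)\n");
            return false; // let the caller decide how to recover
        }
        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
        int best = 0;
        for (int tok = 1; tok < n_vocab; tok++) {
            if (logits[tok] > logits[best]) {
                best = tok;
            }
        }
        out = best;
        return true;
    }

The hunks below apply the same idea in two styles: example programs that can simply bail out return an error code, while the common sampling code reports the failure by throwing std::runtime_error.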
@@ -195,6 +195,9 @@ static llama_token llama_sampling_sample_impl(
     llama_token id = 0;
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -284,6 +287,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
@@ -298,6 +304,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        if (!logits_guidance) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
 
@@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
 
             auto n_vocab = llama_n_vocab(model);
             auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
+            if (!logits) {
+                return 1;
+            }
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         // sum up all token embeddings
         for (int32_t k = n_inst; k < n_toks; k++) {
             float * emb = llama_get_embeddings_ith(ctx, k);
+            if (!emb) {
+                throw std::runtime_error("llama_get_embeddings_ith failed");
+            }
             for (uint64_t j = 0; j < n_embd; j++) {
                 emb_unorm[j] += emb[j];
             }
@@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
 
         llama_decode(ctx, bat);
         auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+        if (!logits) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
 
         auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
         auto n_candidates = (int32_t)candidates.size();
@@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
 
     auto n_vocab = llama_n_vocab(model);
     auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
@@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
         {
             auto n_vocab = llama_n_vocab(model);
             auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+            if (!logits) {
+                return 1;
+            }
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         for (int seq = 0; seq < n_seq_batch; seq++) {
             const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+            if (!all_logits) {
+                return 1;
+            }
 
             llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
             if (!params.logits_file.empty()) {
@@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
     {
         auto n_vocab = llama_n_vocab(model);
        auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }
 
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
llama.cpp
@@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits;
 }
 
+static float * llama_get_logits_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
     llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    if (ctx->logits == nullptr) {
+        // this can happen for embeddings models like bert
+        return llama_get_logits_ith_fail(i, "no logits");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_logits_ith_fail(i, format("negative index out of range [0, %d)", ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }
+    if (j < 0) {
+        return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
     }
+    return ctx->logits + j*ctx->model.hparams.n_vocab;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embd;
 }
 
+static float * llama_get_embeddings_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
 
     llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    if (ctx->embd == nullptr) {
+        return llama_get_embeddings_ith_fail(i, "no embeddings");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
+        if (j < 0) {
+            return llama_get_embeddings_ith_fail(
+                i, format("negative index out of range [0, %d)", ctx->n_outputs));
+        }
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_embeddings_ith_fail(
+            i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
+    }
+    if (j < 0) {
+        return llama_get_embeddings_ith_fail(
+            i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_embeddings_ith_fail(
+            i, format("corrupt output buffer (j=%d, n_outputs=%d)",
+                j, ctx->n_outputs));
     }
+    return ctx->embd + j*ctx->model.hparams.n_embd;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {