llama : clearer error messages for invalid logits or embeddings ids
* llama : assert all models that can have inp_out_ids

  Since the graph topology is now constant, this presence check
  can be done even when there are no outputs.

* llama : assert logits and embd buffers exist before writing to them
parent 8f70dcb0f3
commit 615a3a4a50
1 changed file with 47 additions and 19 deletions
llama.cpp | 66
@@ -8954,7 +8954,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (lctx.n_outputs > 0 && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
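Note on the GGML_ASSERT in the hunk above: it uses the common assert-with-message idiom, where the condition is ANDed with a non-empty (and therefore truthy) string literal so the explanation shows up in the failure output without changing the check. A minimal standalone sketch of the same idiom with plain assert; the helper name is illustrative and not part of llama.cpp:

    #include <cassert>

    // Illustrative only: fail with a readable message when the tensor was never
    // created, mirroring GGML_ASSERT(lctx.inp_out_ids && "...").
    static void require_out_ids(const void * inp_out_ids) {
        assert(inp_out_ids != nullptr && "every model that can must skip unused outputs");
    }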
@@ -9514,6 +9514,7 @@ static int llama_decode_internal(
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
             GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
@@ -9534,6 +9535,7 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_NONE:
                     {
                         // extract token embeddings
+                        GGML_ASSERT(lctx.embd != nullptr);
                         float * embd_out = lctx.embd + n_outputs_prev*n_embd;
                         const int32_t n_outputs_new = lctx.n_outputs;
 
@@ -14623,20 +14625,33 @@ float * llama_get_logits(struct llama_context * ctx) {
 }
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    const int32_t j = ctx->output_ids[i];
-
     llama_synchronize(ctx);
 
-    if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
+    try {
+        if (ctx->logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+        if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        }
+        const int32_t j = ctx->output_ids[i];
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if ((size_t) j >= ctx->output_size) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+        }
+
         return ctx->logits + j*ctx->model.hparams.n_vocab;
-    }
-    LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
-        __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
-        j, ctx->output_size);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-    GGML_ASSERT(false);
+        GGML_ASSERT(false);
 #endif
-    return nullptr;
+        return nullptr;
+    }
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
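With the rewrite above, llama_get_logits_ith() logs a specific reason (no logits buffer, out-of-range index, logits not requested for that position, or a corrupt output buffer) and returns nullptr in release builds instead of indexing blindly. A minimal caller-side sketch, assuming an initialized ctx and an already-decoded batch; i_batch and fetch_logits are illustrative names, not part of this commit:

    #include "llama.h"

    // Returns the logits (n_vocab floats) for output position i_batch, or nullptr
    // if the id is invalid; llama_get_logits_ith has already logged the reason.
    static float * fetch_logits(struct llama_context * ctx, int32_t i_batch) {
        float * logits = llama_get_logits_ith(ctx, i_batch);
        if (logits == nullptr) {
            // e.g. batch.logits[i_batch] was not set to true before llama_decode()
            return nullptr;
        }
        return logits;
    }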
@@ -14646,20 +14661,33 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    const int32_t j = ctx->output_ids[i];
-
     llama_synchronize(ctx);
 
-    if (ctx->embd && 0 <= j && (size_t) j < ctx->output_size) {
+    try {
+        if (ctx->embd == nullptr) {
+            throw std::runtime_error("no embeddings");
+        }
+        if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        }
+        const int32_t j = ctx->output_ids[i];
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if ((size_t) j >= ctx->output_size) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+        }
+
         return ctx->embd + j*ctx->model.hparams.n_embd;
-    }
-    LLAMA_LOG_ERROR("%s: invalid embeddings id %i, reason: %s (j=%i, output_size=%li)\n",
-        __func__, i, !ctx->embd ? "no embeddings" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
-        j, ctx->output_size);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-    GGML_ASSERT(false);
+        GGML_ASSERT(false);
 #endif
-    return nullptr;
+        return nullptr;
+    }
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
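llama_get_embeddings_ith() gets the same treatment, so a caller extracting per-token embeddings (embeddings enabled, LLAMA_POOLING_TYPE_NONE) can key off the nullptr return. A sketch under those assumptions; fetch_embedding and i_batch are illustrative names and the model/context setup is omitted:

    #include <vector>

    #include "llama.h"

    // Copies the embedding (n_embd floats) of output position i_batch into out.
    // Returns false if the id was invalid; the reason was already logged.
    static bool fetch_embedding(struct llama_context * ctx, const struct llama_model * model,
                                int32_t i_batch, std::vector<float> & out) {
        const float * embd = llama_get_embeddings_ith(ctx, i_batch);
        if (embd == nullptr) {
            return false;
        }
        out.assign(embd, embd + llama_n_embd(model));
        return true;
    }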