llama : clearer error messages for invalid logits or embeddings ids

* llama : assert all models that can have inp_out_ids

Since the graph topology is now constant, this presence check
can be done even when there are no outputs.

* llama : assert logits and embd buffers exist before writing to them
This commit is contained in:
Francis Couture-Harpin 2024-03-19 15:01:21 -04:00
parent 8f70dcb0f3
commit 615a3a4a50

View file

@ -8954,7 +8954,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
if (lctx.n_outputs > 0 && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) {
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens;
@ -9514,6 +9514,7 @@ static int llama_decode_internal(
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(lctx.logits != nullptr);
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
@ -9534,6 +9535,7 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
float * embd_out = lctx.embd + n_outputs_prev*n_embd;
const int32_t n_outputs_new = lctx.n_outputs;
@ -14623,20 +14625,33 @@ float * llama_get_logits(struct llama_context * ctx) {
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
llama_synchronize(ctx);
if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
try {
if (ctx->logits == nullptr) {
throw std::runtime_error("no logits");
}
if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
}
const int32_t j = ctx->output_ids[i];
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if ((size_t) j >= ctx->output_size) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
}
return ctx->logits + j*ctx->model.hparams.n_vocab;
}
LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
__func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
j, ctx->output_size);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ASSERT(false);
GGML_ASSERT(false);
#endif
return nullptr;
return nullptr;
}
}
float * llama_get_embeddings(struct llama_context * ctx) {
@ -14646,20 +14661,33 @@ float * llama_get_embeddings(struct llama_context * ctx) {
}
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
llama_synchronize(ctx);
if (ctx->embd && 0 <= j && (size_t) j < ctx->output_size) {
try {
if (ctx->embd == nullptr) {
throw std::runtime_error("no embeddings");
}
if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
}
const int32_t j = ctx->output_ids[i];
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if ((size_t) j >= ctx->output_size) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
}
return ctx->embd + j*ctx->model.hparams.n_embd;
}
LLAMA_LOG_ERROR("%s: invalid embeddings id %i, reason: %s (j=%i, output_size=%li)\n",
__func__, i, !ctx->embd ? "no embeddings" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
j, ctx->output_size);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ASSERT(false);
GGML_ASSERT(false);
#endif
return nullptr;
return nullptr;
}
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {