llama : clearer error messages for invalid logits or embeddings ids
* llama : assert all models that can have inp_out_ids

  Since the graph topology is now constant, this presence check
  can be done even when there are no outputs.

* llama : assert logits and embd buffers exist before writing to them
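For context (not part of the commit), a hedged caller-side sketch of the behavior these changes produce: requesting logits at a position that was never flagged in `batch.logits` now logs a specific reason and returns `nullptr` in release builds, while debug builds still trip `GGML_ASSERT(false)`. The helper name and setup below are illustrative only:

// Illustrative only: assumes `ctx` was created normally and `n_tokens`
// tokens were already written into `batch`, with logits flagged only for
// the last token (the common text-generation setup).
#include "llama.h"
#include <cstdio>

static void demo_error_path(llama_context * ctx, llama_batch & batch, int32_t n_tokens) {
    llama_decode(ctx, batch);

    float * ok = llama_get_logits_ith(ctx, n_tokens - 1); // flagged -> valid pointer

    // Not flagged: with this commit the library logs something like
    //   "llama_get_logits_ith: invalid logits id 0, reason: batch.logits[0] != true"
    float * bad = llama_get_logits_ith(ctx, 0);
    if (bad == nullptr) {
        fprintf(stderr, "no logits at position 0, as expected\n");
    }
    (void) ok;
}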
parent 8f70dcb0f3
commit 615a3a4a50

1 changed file with 47 additions and 19 deletions
llama.cpp
@@ -8954,7 +8954,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (lctx.n_outputs > 0 && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
@@ -9514,6 +9514,7 @@ static int llama_decode_internal(
     if (res) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
         GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(lctx.logits != nullptr);
 
         float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
         const int32_t n_outputs_new = lctx.n_outputs;
@@ -9534,6 +9535,7 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_NONE:
                     {
                         // extract token embeddings
+                        GGML_ASSERT(lctx.embd != nullptr);
                         float * embd_out = lctx.embd + n_outputs_prev*n_embd;
                         const int32_t n_outputs_new = lctx.n_outputs;
 
@@ -14623,20 +14625,33 @@ float * llama_get_logits(struct llama_context * ctx) {
 }
 
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    const int32_t j = ctx->output_ids[i];
-
     llama_synchronize(ctx);
 
-    if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    }
-    LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
-            __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
-            j, ctx->output_size);
+    try {
+        if (ctx->logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+        if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        }
+        const int32_t j = ctx->output_ids[i];
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if ((size_t) j >= ctx->output_size) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+        }
+
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-    GGML_ASSERT(false);
+        GGML_ASSERT(false);
 #endif
-    return nullptr;
+        return nullptr;
+    }
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -14646,20 +14661,33 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    const int32_t j = ctx->output_ids[i];
-
     llama_synchronize(ctx);
 
-    if (ctx->embd && 0 <= j && (size_t) j < ctx->output_size) {
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    }
-    LLAMA_LOG_ERROR("%s: invalid embeddings id %i, reason: %s (j=%i, output_size=%li)\n",
-            __func__, i, !ctx->embd ? "no embeddings" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
-            j, ctx->output_size);
+    try {
+        if (ctx->embd == nullptr) {
+            throw std::runtime_error("no embeddings");
+        }
+        if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        }
+        const int32_t j = ctx->output_ids[i];
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if ((size_t) j >= ctx->output_size) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+        }
+
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
-    GGML_ASSERT(false);
+        GGML_ASSERT(false);
 #endif
-    return nullptr;
+        return nullptr;
+    }
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
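Note: the new error paths call llama.cpp's internal printf-style `format()` helper, which is not shown in this diff. For readers outside the codebase, a minimal sketch of that kind of helper follows; this is an assumption about its general shape, not the project's exact implementation:

#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Sketch of a printf-style formatter returning std::string, in the spirit
// of the format() helper the diff calls into; upstream details may differ.
static std::string format(const char * fmt, ...) {
    va_list ap1, ap2;
    va_start(ap1, fmt);
    va_copy(ap2, ap1);
    const int n = vsnprintf(nullptr, 0, fmt, ap1); // measure required length
    va_end(ap1);
    std::vector<char> buf(n + 1);
    vsnprintf(buf.data(), buf.size(), fmt, ap2);   // write formatted text
    va_end(ap2);
    return std::string(buf.data(), n);
}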