llama : fix llama_get_embeddings_ith when the resulting id is 0

This commit is contained in:
Francis Couture-Harpin 2024-03-17 15:34:56 -04:00
parent 487f89ec2e
commit 408fcb0f91

View file

@ -14650,11 +14650,13 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];
llama_synchronize(ctx);
if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
return ctx->logits + j*ctx->model.hparams.n_vocab;
}
LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
__func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", j, ctx->output_size);
__func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
j, ctx->output_size);
#ifndef NDEBUG
GGML_ASSERT(false);
#endif
@ -14672,10 +14674,12 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
llama_synchronize(ctx);
if (ctx->embd && 0 < j && (size_t) j < ctx->output_size) {
if (ctx->embd && 0 <= j && (size_t) j < ctx->output_size) {
return ctx->embd + j*ctx->model.hparams.n_embd;
}
LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);
LLAMA_LOG_ERROR("%s: invalid embeddings id %i, reason: %s (j=%i, output_size=%li)\n",
__func__, i, !ctx->embd ? "no embeddings" : j < 0 ? "batch.logits[i] wasn't true" : "too big",
j, ctx->output_size);
#ifndef NDEBUG
GGML_ASSERT(false);
#endif