llama : sanity checks for access to logits
This commit is contained in:
parent
1f5cd83275
commit
f91707bbe1
1 changed files with 20 additions and 0 deletions
20
llama.cpp
20
llama.cpp
|
@ -1468,6 +1468,10 @@ struct llama_context {
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// guard against access to unset logits
|
||||||
|
std::vector<bool> logits_valid;
|
||||||
|
#endif
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
|
||||||
// input embedding (1-dimensional array: [n_embd])
|
// input embedding (1-dimensional array: [n_embd])
|
||||||
|
@ -5609,6 +5613,12 @@ static int llama_decode_internal(
|
||||||
{
|
{
|
||||||
auto & logits_out = lctx.logits;
|
auto & logits_out = lctx.logits;
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
auto & logits_valid = lctx.logits_valid;
|
||||||
|
logits_valid.clear();
|
||||||
|
logits_valid.resize(n_vocab);
|
||||||
|
#endif
|
||||||
|
|
||||||
if (batch.logits) {
|
if (batch.logits) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
|
@ -5616,13 +5626,22 @@ static int llama_decode_internal(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
logits_valid[i] = true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all) {
|
} else if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
logits_valid[n_tokens - 1] = true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9465,6 +9484,7 @@ float * llama_get_logits(struct llama_context * ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||||
|
assert(ctx->logits_valid.at(i));
|
||||||
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
|
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue