This commit is contained in:
Justine Tunney 2024-05-24 18:49:07 -04:00 committed by GitHub
commit a64f086d72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 128 additions and 66 deletions

View file

@ -270,12 +270,16 @@ static llama_token llama_sampling_sample_impl(
std::vector<float> original_logits; std::vector<float> original_logits;
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
if (cur_p.data == NULL) {
return -1;
}
if (ctx_sampling->grammar != NULL && !is_resampling) { if (ctx_sampling->grammar != NULL && !is_resampling) {
GGML_ASSERT(!original_logits.empty()); GGML_ASSERT(!original_logits.empty());
} }
llama_token id = 0; llama_token id = 0;
// Get a pointer to the logits // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
GGML_ASSERT(logits); // already checked in llama_sampling_prepare
if (temp < 0.0) { if (temp < 0.0) {
// greedy sampling, with probs // greedy sampling, with probs
@ -365,6 +369,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
// Get a pointer to the logits // Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
if (!logits) {
return {NULL, 0, false};
}
if (ctx_sampling->grammar != NULL && !apply_grammar) { if (ctx_sampling->grammar != NULL && !apply_grammar) {
GGML_ASSERT(original_logits != NULL); GGML_ASSERT(original_logits != NULL);
@ -379,6 +386,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
if (ctx_cfg) { if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
if (!logits_guidance) {
return {NULL, 0, false};
}
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
} }

View file

@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
auto n_vocab = llama_n_vocab(model); auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]); auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
if (!logits) {
return 1;
}
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);

View file

@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// sum up all token embeddings // sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) { for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k); float * emb = llama_get_embeddings_ith(ctx, k);
if (!emb) {
throw std::runtime_error("llama_get_embeddings_ith failed");
}
for (uint64_t j = 0; j < n_embd; j++) { for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j]; emb_unorm[j] += emb[j];
} }
@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_decode(ctx, bat); llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
if (!logits) {
throw std::runtime_error("llama_get_logits_ith failed");
}
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl)); auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
auto n_candidates = (int32_t)candidates.size(); auto n_candidates = (int32_t)candidates.size();

View file

@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
if (id == -1) {
return 1;
}
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);

View file

@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
auto n_vocab = llama_n_vocab(model); auto n_vocab = llama_n_vocab(model);
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
if (!logits) {
throw std::runtime_error("llama_get_logits_ith failed");
}
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);

View file

@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_llama, struct llama_context * ctx_llama,
int * n_past) { int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
GGML_ASSERT(id != -1);
llama_sampling_accept(ctx_sampling, ctx_llama, id, true); llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

View file

@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
// sample first token // sample first token
{ {
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
if (id == -1) {
return 1;
}
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);
@ -284,6 +287,9 @@ int main(int argc, char ** argv) {
// sample the next token // sample the next token
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
if (id == -1) {
return 1;
}
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);
@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
// sample from the last level // sample from the last level
for (int i = 0; i < W; i++) { for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i); tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
if (tokens_j[N - 2][i] == -1) {
return 1;
}
} }
} else { } else {
for (int i = 0; i < W; i++) { for (int i = 0; i < W; i++) {

View file

@ -131,6 +131,7 @@ int main(int argc, char ** argv){
while (true) { while (true) {
// sample from the target model // sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
GGML_ASSERT(id != -1);
llama_sampling_accept(ctx_sampling, ctx, id, true); llama_sampling_accept(ctx_sampling, ctx, id, true);

View file

@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
} }
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
if (id == -1) {
return 1;
}
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

View file

@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
GGML_ASSERT(id != -1);
llama_sampling_accept(client.ctx_sampling, ctx, id, true); llama_sampling_accept(client.ctx_sampling, ctx, id, true);

View file

@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
{ {
auto n_vocab = llama_n_vocab(model); auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
if (!logits) {
return 1;
}
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);

View file

@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
for (int seq = 0; seq < n_seq_batch; seq++) { for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
if (!all_logits) {
return {std::move(tokens), -1, {}, {}};
}
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {

View file

@ -2257,6 +2257,14 @@ struct server_context {
completion_token_output result; completion_token_output result;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
if (id == -1) {
send_error(slot, "can't get completions out of an embeddings model");
slot.cache_tokens.clear();
slot.reset();
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}
llama_sampling_accept(slot.ctx_sampling, ctx, id, true); llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

View file

@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
{ {
auto n_vocab = llama_n_vocab(model); auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
if (!logits) {
return 1;
}
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);

View file

@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
// stochastic verification // stochastic verification
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
if (dist_tgt.data == NULL) {
return 1;
}
llama_sample_softmax(ctx_tgt, &dist_tgt); llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0; float p_tgt = 0, p_dft = 0;
@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
// sample from the target model // sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
if (token_id == -1) {
return 1;
}
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
continue; continue;
} }
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
return -1;
}
const auto & cur_p = drafts[s].ctx_sampling->cur; const auto & cur_p = drafts[s].ctx_sampling->cur;

123
llama.cpp
View file

@ -17745,42 +17745,39 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits; return ctx->logits;
} }
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { static float * llama_get_logits_ith_fail(int i, const std::string & reason) {
int32_t j = -1; LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
llama_synchronize(ctx);
try {
if (ctx->logits == nullptr) {
throw std::runtime_error("no logits");
}
if (i < 0) {
j = ctx->n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
}
} else if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
} else {
j = ctx->output_ids[i];
}
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= ctx->n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->logits + j*ctx->model.hparams.n_vocab;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG #ifndef NDEBUG
GGML_ASSERT(false); GGML_ASSERT(false);
#endif #endif
return nullptr; return nullptr;
} }
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
int32_t j = -1;
llama_synchronize(ctx);
if (ctx->logits == nullptr) {
// this can happen for embeddings models like bert
return llama_get_logits_ith_fail(i, "no logits");
}
if (i < 0) {
j = ctx->n_outputs + i;
if (j < 0) {
return llama_get_logits_ith_fail(i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
}
} else if ((size_t) i >= ctx->output_ids.size()) {
return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
} else {
j = ctx->output_ids[i];
}
if (j < 0) {
return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
}
if (j >= ctx->n_outputs) {
// This should not happen
return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->logits + j*ctx->model.hparams.n_vocab;
} }
float * llama_get_embeddings(struct llama_context * ctx) { float * llama_get_embeddings(struct llama_context * ctx) {
@ -17789,43 +17786,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embd; return ctx->embd;
} }
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { static float * llama_get_embeddings_ith_fail(int i, const std::string & reason) {
int32_t j = -1; LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
llama_synchronize(ctx);
try {
if (ctx->embd == nullptr) {
throw std::runtime_error("no embeddings");
}
if (i < 0) {
j = ctx->n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
}
} else if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
} else {
j = ctx->output_ids[i];
}
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= ctx->n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->embd + j*ctx->model.hparams.n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG #ifndef NDEBUG
GGML_ASSERT(false); GGML_ASSERT(false);
#endif #endif
return nullptr; return nullptr;
} }
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
int32_t j = -1;
llama_synchronize(ctx);
if (ctx->embd == nullptr) {
return llama_get_embeddings_ith_fail(i, "no embeddings");
}
if (i < 0) {
j = ctx->n_outputs + i;
if (j < 0) {
return llama_get_embeddings_ith_fail(
i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
}
} else if ((size_t) i >= ctx->output_ids.size()) {
return llama_get_embeddings_ith_fail(
i, format("out of range [0, %lu)", ctx->output_ids.size()));
} else {
j = ctx->output_ids[i];
}
if (j < 0) {
return llama_get_embeddings_ith_fail(
i, format("batch.logits[%d] != true", i));
}
if (j >= ctx->n_outputs) {
// This should not happen
return llama_get_embeddings_ith_fail(
i, format("corrupt output buffer (j=%d, n_outputs=%d)",
j, ctx->n_outputs));
}
return ctx->embd + j*ctx->model.hparams.n_embd;
} }
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {