Merge 8be06dc745
into d041d2ceaa
This commit is contained in:
commit
a64f086d72
16 changed files with 128 additions and 66 deletions
|
@ -270,12 +270,16 @@ static llama_token llama_sampling_sample_impl(
|
||||||
|
|
||||||
std::vector<float> original_logits;
|
std::vector<float> original_logits;
|
||||||
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
|
||||||
|
if (cur_p.data == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
if (ctx_sampling->grammar != NULL && !is_resampling) {
|
||||||
GGML_ASSERT(!original_logits.empty());
|
GGML_ASSERT(!original_logits.empty());
|
||||||
}
|
}
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
GGML_ASSERT(logits); // already checked in llama_sampling_prepare
|
||||||
|
|
||||||
if (temp < 0.0) {
|
if (temp < 0.0) {
|
||||||
// greedy sampling, with probs
|
// greedy sampling, with probs
|
||||||
|
@ -365,6 +369,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
// Get a pointer to the logits
|
// Get a pointer to the logits
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
if (!logits) {
|
||||||
|
return {NULL, 0, false};
|
||||||
|
}
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
if (ctx_sampling->grammar != NULL && !apply_grammar) {
|
||||||
GGML_ASSERT(original_logits != NULL);
|
GGML_ASSERT(original_logits != NULL);
|
||||||
|
@ -379,6 +386,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
|
|
||||||
if (ctx_cfg) {
|
if (ctx_cfg) {
|
||||||
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||||
|
if (!logits_guidance) {
|
||||||
|
return {NULL, 0, false};
|
||||||
|
}
|
||||||
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||||
// sum up all token embeddings
|
// sum up all token embeddings
|
||||||
for (int32_t k = n_inst; k < n_toks; k++) {
|
for (int32_t k = n_inst; k < n_toks; k++) {
|
||||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||||
|
if (!emb) {
|
||||||
|
throw std::runtime_error("llama_get_embeddings_ith failed");
|
||||||
|
}
|
||||||
for (uint64_t j = 0; j < n_embd; j++) {
|
for (uint64_t j = 0; j < n_embd; j++) {
|
||||||
emb_unorm[j] += emb[j];
|
emb_unorm[j] += emb[j];
|
||||||
}
|
}
|
||||||
|
@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
|
||||||
|
|
||||||
llama_decode(ctx, bat);
|
llama_decode(ctx, bat);
|
||||||
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
||||||
auto n_candidates = (int32_t)candidates.size();
|
auto n_candidates = (int32_t)candidates.size();
|
||||||
|
|
|
@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
|
||||||
|
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
|
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
throw std::runtime_error("llama_get_logits_ith failed");
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_llama,
|
struct llama_context * ctx_llama,
|
||||||
int * n_past) {
|
int * n_past) {
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||||
static std::string ret;
|
static std::string ret;
|
||||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||||
|
|
|
@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample first token
|
// sample first token
|
||||||
{
|
{
|
||||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
@ -284,6 +287,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// sample the next token
|
// sample the next token
|
||||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample from the last level
|
// sample from the last level
|
||||||
for (int i = 0; i < W; i++) {
|
for (int i = 0; i < W; i++) {
|
||||||
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||||
|
if (tokens_j[N - 2][i] == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < W; i++) {
|
for (int i = 0; i < W; i++) {
|
||||||
|
|
|
@ -131,6 +131,7 @@ int main(int argc, char ** argv){
|
||||||
while (true) {
|
while (true) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
if (id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
||||||
|
|
||||||
|
|
|
@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
|
||||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||||
|
|
||||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||||
|
GGML_ASSERT(id != -1);
|
||||||
|
|
||||||
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
|
|
||||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||||
|
if (!all_logits) {
|
||||||
|
return {std::move(tokens), -1, {}, {}};
|
||||||
|
}
|
||||||
|
|
||||||
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||||
if (!params.logits_file.empty()) {
|
if (!params.logits_file.empty()) {
|
||||||
|
|
|
@ -2257,6 +2257,14 @@ struct server_context {
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
||||||
|
if (id == -1) {
|
||||||
|
send_error(slot, "can't get completions out of an embeddings model");
|
||||||
|
slot.cache_tokens.clear();
|
||||||
|
slot.reset();
|
||||||
|
slot.release();
|
||||||
|
slot.i_batch = -1;
|
||||||
|
continue; // continue loop of slots
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(model);
|
auto n_vocab = llama_n_vocab(model);
|
||||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||||
|
if (!logits) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
|
|
@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
|
||||||
// stochastic verification
|
// stochastic verification
|
||||||
|
|
||||||
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
||||||
|
if (dist_tgt.data == NULL) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
||||||
float p_tgt = 0, p_dft = 0;
|
float p_tgt = 0, p_dft = 0;
|
||||||
|
|
||||||
|
@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
if (token_id == -1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
|
if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
||||||
|
|
||||||
|
|
123
llama.cpp
123
llama.cpp
|
@ -17745,42 +17745,39 @@ float * llama_get_logits(struct llama_context * ctx) {
|
||||||
return ctx->logits;
|
return ctx->logits;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
static float * llama_get_logits_ith_fail(int i, const std::string & reason) {
|
||||||
int32_t j = -1;
|
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
|
||||||
llama_synchronize(ctx);
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (ctx->logits == nullptr) {
|
|
||||||
throw std::runtime_error("no logits");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i < 0) {
|
|
||||||
j = ctx->n_outputs + i;
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
|
||||||
}
|
|
||||||
} else if ((size_t) i >= ctx->output_ids.size()) {
|
|
||||||
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
|
|
||||||
} else {
|
|
||||||
j = ctx->output_ids[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
||||||
}
|
|
||||||
if (j >= ctx->n_outputs) {
|
|
||||||
// This should not happen
|
|
||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx->logits + j*ctx->model.hparams.n_vocab;
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
#endif
|
#endif
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||||
|
int32_t j = -1;
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
if (ctx->logits == nullptr) {
|
||||||
|
// this can happen for embeddings models like bert
|
||||||
|
return llama_get_logits_ith_fail(i, "no logits");
|
||||||
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
j = ctx->n_outputs + i;
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
|
||||||
|
}
|
||||||
|
} else if ((size_t) i >= ctx->output_ids.size()) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
|
||||||
|
} else {
|
||||||
|
j = ctx->output_ids[i];
|
||||||
|
}
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
|
||||||
|
}
|
||||||
|
if (j >= ctx->n_outputs) {
|
||||||
|
// This should not happen
|
||||||
|
return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
||||||
|
}
|
||||||
|
return ctx->logits + j*ctx->model.hparams.n_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings(struct llama_context * ctx) {
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
|
@ -17789,43 +17786,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
return ctx->embd;
|
return ctx->embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
static float * llama_get_embeddings_ith_fail(int i, const std::string & reason) {
|
||||||
int32_t j = -1;
|
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
|
||||||
|
|
||||||
llama_synchronize(ctx);
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (ctx->embd == nullptr) {
|
|
||||||
throw std::runtime_error("no embeddings");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (i < 0) {
|
|
||||||
j = ctx->n_outputs + i;
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
|
|
||||||
}
|
|
||||||
} else if ((size_t) i >= ctx->output_ids.size()) {
|
|
||||||
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
|
|
||||||
} else {
|
|
||||||
j = ctx->output_ids[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j < 0) {
|
|
||||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
|
||||||
}
|
|
||||||
if (j >= ctx->n_outputs) {
|
|
||||||
// This should not happen
|
|
||||||
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx->embd + j*ctx->model.hparams.n_embd;
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
#endif
|
#endif
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
|
||||||
|
int32_t j = -1;
|
||||||
|
llama_synchronize(ctx);
|
||||||
|
if (ctx->embd == nullptr) {
|
||||||
|
return llama_get_embeddings_ith_fail(i, "no embeddings");
|
||||||
|
}
|
||||||
|
if (i < 0) {
|
||||||
|
j = ctx->n_outputs + i;
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("negative index out of range [%d, 0)", -ctx->n_outputs));
|
||||||
|
}
|
||||||
|
} else if ((size_t) i >= ctx->output_ids.size()) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("out of range [0, %lu)", ctx->output_ids.size()));
|
||||||
|
} else {
|
||||||
|
j = ctx->output_ids[i];
|
||||||
|
}
|
||||||
|
if (j < 0) {
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("batch.logits[%d] != true", i));
|
||||||
|
}
|
||||||
|
if (j >= ctx->n_outputs) {
|
||||||
|
// This should not happen
|
||||||
|
return llama_get_embeddings_ith_fail(
|
||||||
|
i, format("corrupt output buffer (j=%d, n_outputs=%d)",
|
||||||
|
j, ctx->n_outputs));
|
||||||
|
}
|
||||||
|
return ctx->embd + j*ctx->model.hparams.n_embd;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue