removed old llama_token functions

This commit is contained in:
Marcus Dunn 2023-10-23 09:15:48 -07:00
parent 353f4ef717
commit 22d5eb41bb
4 changed files with 23 additions and 59 deletions

View file

@ -879,13 +879,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
if (params.ignore_eos) { if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY; params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
} }
{ {
LOG("warming up the model with an empty run\n"); LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1); llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx); llama_reset_timings(lctx);
@ -940,7 +940,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
} }
std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) { std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_token bos_id = llama_token_bos(ctx); const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
std::string piece; std::string piece;
std::string result; std::string result;
@ -1185,7 +1185,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx)); const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");

View file

@ -147,7 +147,7 @@ llama_token llama_sampling_sample(
// apply penalties // apply penalties
if (!prev.empty()) { if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)]; const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p, llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - penalty_last_n, prev.data() + prev.size() - penalty_last_n,
@ -155,7 +155,7 @@ llama_token llama_sampling_sample(
if (!penalize_nl) { if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) { for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) { if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit; cur_p.data[idx].logit = nl_logit;
break; break;
} }

View file

@ -7473,7 +7473,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} }
} }
const llama_token eos = llama_token_eos(ctx); const llama_token eos = llama_token_eos(&ctx->model);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
@ -7683,7 +7683,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
if (token == llama_token_eos(ctx)) { if (token == llama_token_eos(&ctx->model)) {
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -8892,7 +8892,7 @@ struct llama_context * llama_new_context_with_model(
// build worst-case graph // build worst-case graph
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_past = cparams.n_ctx - n_tokens; int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
@ -9665,58 +9665,31 @@ llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_to
return ctx->model.vocab.id_to_token[token].type; return ctx->model.vocab.id_to_token[token].type;
} }
llama_token llama_token_bos(const struct llama_context * ctx) { llama_token llama_token_bos(const struct llama_model * model) {
return ctx->model.vocab.special_bos_id;
}
llama_token llama_model_token_bos(const struct llama_model * model) {
return model->vocab.special_bos_id; return model->vocab.special_bos_id;
} }
llama_token llama_token_eos(const struct llama_context * ctx) { llama_token llama_token_eos(const struct llama_model * model) {
return ctx->model.vocab.special_eos_id;
}
llama_token llama_model_token_eos(const struct llama_model * model) {
return model->vocab.special_eos_id; return model->vocab.special_eos_id;
} }
llama_token llama_token_nl(const struct llama_context * ctx) { llama_token llama_token_nl(const struct llama_model * model) {
return ctx->model.vocab.linefeed_id;
}
llama_token llama_model_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id; return model->vocab.linefeed_id;
} }
llama_token llama_token_prefix(const struct llama_context * ctx) {
return ctx->model.vocab.special_prefix_id;
}
llama_token llama_model_token_prefix(const struct llama_model * model) { llama_token llama_token_prefix(const struct llama_model * model) {
return model->vocab.special_prefix_id; return model->vocab.special_prefix_id;
} }
llama_token llama_token_middle(const struct llama_context * ctx) { llama_token llama_token_middle(const struct llama_model * model) {
return ctx->model.vocab.special_middle_id;
}
llama_token llama_model_token_middle(const struct llama_model * model) {
return model->vocab.special_middle_id; return model->vocab.special_middle_id;
} }
llama_token llama_token_suffix(const struct llama_context * ctx) { llama_token llama_token_suffix(const struct llama_model * model) {
return ctx->model.vocab.special_suffix_id;
}
llama_token llama_model_token_suffix(const struct llama_model * model) {
return model->vocab.special_suffix_id; return model->vocab.special_suffix_id;
} }
llama_token llama_token_eot(const struct llama_context * ctx) { llama_token llama_token_eot(const struct llama_model * model) {
return ctx->model.vocab.special_eot_id;
}
llama_token llama_model_token_eot(const struct llama_model * model) {
return model->vocab.special_eot_id; return model->vocab.special_eot_id;
} }

23
llama.h
View file

@ -501,24 +501,15 @@ extern "C" {
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model);
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model);
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_model * model);
LLAMA_API llama_token llama_model_token_bos(const struct llama_model * model);
LLAMA_API llama_token llama_model_token_eos(const struct llama_model * model);
LLAMA_API llama_token llama_model_token_nl (const struct llama_model * model);
// codellama infill tokens // codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
LLAMA_API llama_token llama_model_token_prefix(const struct llama_model * model);
LLAMA_API llama_token llama_model_token_middle(const struct llama_model * model);
LLAMA_API llama_token llama_model_token_suffix(const struct llama_model * model);
LLAMA_API llama_token llama_model_token_eot (const struct llama_model * model);
// //
// Tokenization // Tokenization