llama : add llama_token_is_eog()

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-20 16:46:46 +03:00
parent f3105b9eec
commit 3750706962
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
18 changed files with 76 additions and 46 deletions

View file

@ -2531,16 +2531,22 @@ class MambaModel(Model):
field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL) field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1])) self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST) field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES) field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID) field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID) field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0]) self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID) field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0]) self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])

View file

@ -153,7 +153,7 @@ while n_cur <= n_len {
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? -> mark the stream as finished // is it an end of stream? -> mark the stream as finished
if new_token_id == llama_token_eos(model) || n_cur == n_len { if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
i_batch[i] = -1 i_batch[i] = -1
// print("") // print("")
if n_parallel > 1 { if n_parallel > 1 {

View file

@ -191,8 +191,8 @@ int main(int argc, char ** argv) {
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (new_token_id == llama_token_eos(model) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
i_batch[i] = -1; i_batch[i] = -1;
LOG_TEE("\n"); LOG_TEE("\n");
if (n_parallel > 1) { if (n_parallel > 1) {

View file

@ -47,7 +47,7 @@ struct beam_search_callback_data {
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same. // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc. // For example, eob can be flagged due to maximum token length, stop words, etc.
static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) { static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx)); return n_tokens && llama_token_is_eog(llama_get_model(callback_data.ctx), tokens[n_tokens-1]);
} }
// Function matching type llama_beam_search_callback_fn_t. // Function matching type llama_beam_search_callback_fn_t.

View file

@ -651,8 +651,8 @@ int main(int argc, char ** argv) {
// LOG_TEE("took new input\n"); // LOG_TEE("took new input\n");
is_interacting = false; is_interacting = false;
} }
// deal with end of text token in interactive mode // deal with end of generation tokens in interactive mode
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -731,8 +731,8 @@ int main(int argc, char ** argv) {
} }
} }
// end of text token // end of generation
if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) { if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) {
break; break;
} }

View file

@ -408,7 +408,7 @@ Java_com_example_llama_Llm_completion_1loop(
const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (new_token_id == llama_token_eos(model) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
return env->NewStringUTF(""); return env->NewStringUTF("");
} }

View file

@ -158,7 +158,7 @@ actor LlamaContext {
new_token_id = llama_sample_token_greedy(context, &candidates_p) new_token_id = llama_sample_token_greedy(context, &candidates_p)
} }
if new_token_id == llama_token_eos(model) || n_cur == n_len { if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n") print("\n")
let new_token_str = String(cString: temporary_invalid_cchars + [0]) let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll() temporary_invalid_cchars.removeAll()

View file

@ -45,7 +45,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
llama_sampling_accept(ctx_sampling, ctx_llama, id, true); llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
static std::string ret; static std::string ret;
if (id == llama_token_eos(llama_get_model(ctx_llama))) { if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = llama_token_to_piece(ctx_llama, id); ret = llama_token_to_piece(ctx_llama, id);

View file

@ -299,7 +299,7 @@ int main(int argc, char ** argv) {
} }
fflush(stdout); fflush(stdout);
if (id == llama_token_eos(model)) { if (llama_token_is_eog(model, id)) {
has_eos = true; has_eos = true;
} }

View file

@ -141,7 +141,7 @@ int main(int argc, char ** argv){
printf("%s", token_str.c_str()); printf("%s", token_str.c_str());
} }
if (id == llama_token_eos(model)) { if (llama_token_is_eog(model, id)) {
has_eos = true; has_eos = true;
} }

View file

@ -795,8 +795,8 @@ int main(int argc, char ** argv) {
} }
} }
// deal with end of text token in interactive mode // deal with end of generation tokens in interactive mode
if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -920,8 +920,8 @@ int main(int argc, char ** argv) {
} }
} }
// end of text token // end of generation
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) { if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.instruct || params.interactive || params.chatml)) {
LOG_TEE(" [end of text]\n"); LOG_TEE(" [end of text]\n");
break; break;
} }

View file

@ -359,7 +359,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (client.n_decoded > 2 && if (client.n_decoded > 2 &&
(id == llama_token_eos(model) || (llama_token_is_eog(model, id) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos || client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) { client.response.find('\n') != std::string::npos)) {

View file

@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? // is it an end of generation?
if (new_token_id == llama_token_eos(model) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
LOG_TEE("\n"); LOG_TEE("\n");
break; break;

View file

@ -1201,7 +1201,7 @@ struct server_context {
}); });
} }
if (result.tok == llama_token_eos(model)) { if (llama_token_is_eog(model, result.tok)) {
slot.stopped_eos = true; slot.stopped_eos = true;
slot.has_next_token = false; slot.has_next_token = false;

View file

@ -133,8 +133,8 @@ int main(int argc, char ** argv) {
// sample the most likely token // sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? // is it an end of generation?
if (new_token_id == llama_token_eos(model) || n_cur == n_len) { if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
LOG_TEE("\n"); LOG_TEE("\n");
break; break;

View file

@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
} }
} }
if (token_id == llama_token_eos(model_tgt)) { if (llama_token_is_eog(model_tgt, token_id)) {
has_eos = true; has_eos = true;
} }
++n_predict; ++n_predict;

View file

@ -4280,6 +4280,7 @@ static void llm_load_vocab(
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
}; };
for (const auto & it : special_token_types) { for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it)); const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it); int32_t & id = std::get<1>(it);
@ -4294,7 +4295,6 @@ static void llm_load_vocab(
} else { } else {
id = new_id; id = new_id;
} }
} }
// Handle add_bos_token and add_eos_token // Handle add_bos_token and add_eos_token
@ -4308,6 +4308,17 @@ static void llm_load_vocab(
vocab.special_add_eos = int(temp); vocab.special_add_eos = int(temp);
} }
} }
// find EOT token "<|eot_id|>"
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
if (vocab.special_eot_id == -1) {
for (const auto & t : vocab.token_to_id) {
if (t.first == "<|eot_id|>" && vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL) {
vocab.special_eot_id = t.second;
break;
}
}
}
} }
// build special tokens cache // build special tokens cache
@ -4477,7 +4488,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
} }
// Returns false if cancelled by progress_callback // Returns false if cancelled by progress_callback
@ -13072,16 +13088,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
GGML_ASSERT(ctx); GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
bool allow_eos = false; bool allow_eog = false;
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
allow_eos = true; allow_eog = true;
break; break;
} }
} }
const llama_token eos = llama_token_eos(&ctx->model);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size); candidates_decoded.reserve(candidates->size);
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
@ -13090,8 +13104,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id); const std::string piece = llama_token_to_piece(ctx, id);
if (id == eos) { if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eos) { if (!allow_eog) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
} }
} else if (piece.empty() || piece[0] == 0) { } else if (piece.empty() || piece[0] == 0) {
@ -13280,7 +13294,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us(); const int64_t t_start_sample_us = ggml_time_us();
if (token == llama_token_eos(&ctx->model)) { if (llama_token_is_eog(&ctx->model, token)) {
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -16683,6 +16697,13 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to
return model->vocab.id_to_token[token].type; return model->vocab.id_to_token[token].type;
} }
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
return token != -1 && (
token == llama_token_eos(model) ||
token == llama_token_eot(model)
);
}
llama_token llama_token_bos(const struct llama_model * model) { llama_token llama_token_bos(const struct llama_model * model) {
return model->vocab.special_bos_id; return model->vocab.special_bos_id;
} }

View file

@ -783,6 +783,9 @@ extern "C" {
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@ -796,7 +799,7 @@ extern "C" {
// Returns -1 if unknown, 1 for true or 0 for false. // Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
// codellama infill tokens // Codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix