llama : remove token functions with context
args in favor of model
(#3720)
* added `llama_model_token_*` variants to all the `llama_token_*` functions. * added `LLAMA_API` * formatting Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * removed old `llama_token` functions * changed 3 more functions to take in model - `llama_token_get_text` - `llama_token_get_score` - `llama_token_get_type` * added back docs * fixed main.cpp * changed token functions to use new model variants * changed token functions to use new model variants --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6336701c93
commit
5be6c803fa
16 changed files with 81 additions and 79 deletions
|
@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
|
|||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
|
||||
i_batch[i] = -1;
|
||||
LOG_TEE("\n");
|
||||
if (n_parallel > 1) {
|
||||
|
|
|
@ -47,7 +47,7 @@ struct beam_search_callback_data {
|
|||
// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
|
||||
// For example, eob can be flagged due to maximum token length, stop words, etc.
|
||||
static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
|
||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
|
||||
return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
|
||||
}
|
||||
|
||||
// Function matching type llama_beam_search_callback_fn_t.
|
||||
|
|
|
@ -246,14 +246,14 @@ int main(int argc, char ** argv) {
|
|||
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
|
||||
inp_sfx.erase(inp_sfx.begin());
|
||||
}
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
|
||||
if (add_bos) {
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
|
||||
}
|
||||
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
|
||||
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
|
||||
embd_inp = inp_pfx;
|
||||
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
||||
embd_inp.push_back(llama_token_middle(ctx));
|
||||
embd_inp.push_back(llama_token_middle(model));
|
||||
|
||||
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
|
||||
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
|
||||
|
@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
embd_inp.push_back(llama_token_bos(ctx));
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
}
|
||||
|
||||
|
@ -577,10 +577,10 @@ int main(int argc, char ** argv) {
|
|||
if ((int) embd_inp.size() <= n_consumed) {
|
||||
|
||||
// deal with eot token in infill mode
|
||||
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
|
||||
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
|
||||
if(is_interacting && !params.interactive_first) {
|
||||
// print an eot token
|
||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
|
||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
|
||||
}
|
||||
fflush(stdout);
|
||||
printf("\n");
|
||||
|
@ -627,14 +627,14 @@ int main(int argc, char ** argv) {
|
|||
if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
|
||||
inp_sfx.erase(inp_sfx.begin());
|
||||
}
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
|
||||
if (add_bos) {
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
|
||||
inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
|
||||
}
|
||||
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
|
||||
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
|
||||
embd_inp = inp_pfx;
|
||||
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
||||
embd_inp.push_back(llama_token_middle(ctx));
|
||||
embd_inp.push_back(llama_token_middle(model));
|
||||
embd.clear();
|
||||
embd_guidance.clear();
|
||||
n_remain = params.n_predict;
|
||||
|
@ -644,7 +644,7 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = false;
|
||||
}
|
||||
// deal with end of text token in interactive mode
|
||||
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
|
||||
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
|
||||
LOG("found EOS token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -661,7 +661,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (params.input_prefix_bos) {
|
||||
LOG("adding input prefix BOS token\n");
|
||||
embd_inp.push_back(llama_token_bos(ctx));
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
}
|
||||
|
||||
std::string buffer;
|
||||
|
@ -724,7 +724,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// end of text token
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !params.interactive) {
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
if (!params.interactive && n_remain <= 0) {
|
||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
|
||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
|
|
|
@ -933,7 +933,7 @@ struct sql_printer : public printer {
|
|||
};
|
||||
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
||||
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
|
||||
std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
|
||||
int n_processed = 0;
|
||||
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
@ -946,7 +946,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||
}
|
||||
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||
llama_token token = llama_token_bos(ctx);
|
||||
llama_token token = llama_token_bos(llama_get_model(ctx));
|
||||
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
|
|
|
@ -137,7 +137,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
|
|||
inline const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
|
||||
int id = sample_id(ctx_llama, params);
|
||||
static std::string ret;
|
||||
if (id == llama_token_eos(ctx_llama)) {
|
||||
if (id == llama_token_eos(llama_get_model(ctx_llama))) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = llama_token_to_piece(ctx_llama, id);
|
||||
|
|
|
@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// Should not run without any tokens
|
||||
if (embd_inp.empty()) {
|
||||
embd_inp.push_back(llama_token_bos(ctx));
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
||||
}
|
||||
|
||||
|
@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// deal with end of text token in interactive mode
|
||||
if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
|
||||
if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
|
||||
LOG("found EOS token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -720,7 +720,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (params.input_prefix_bos) {
|
||||
LOG("adding input prefix BOS token\n");
|
||||
embd_inp.push_back(llama_token_bos(ctx));
|
||||
embd_inp.push_back(llama_token_bos(model));
|
||||
}
|
||||
|
||||
std::string buffer;
|
||||
|
@ -804,7 +804,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// end of text token
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
|
||||
LOG_TEE(" [end of text]\n");
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -347,7 +347,7 @@ int main(int argc, char ** argv) {
|
|||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (client.n_decoded > 2 &&
|
||||
(id == llama_token_eos(ctx) ||
|
||||
(id == llama_token_eos(model) ||
|
||||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos)) {
|
||||
|
|
|
@ -227,7 +227,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_token_bos(ctx);
|
||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||
}
|
||||
|
||||
const auto batch_logits = llama_get_logits(ctx);
|
||||
|
@ -350,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_token_bos(ctx);
|
||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
||||
|
|
|
@ -726,7 +726,7 @@ struct llama_server_context
|
|||
|
||||
if (json_value(data, "ignore_eos", false))
|
||||
{
|
||||
slot->sparams.logit_bias[llama_token_eos(ctx)] = -INFINITY;
|
||||
slot->sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
}
|
||||
|
||||
const auto &logit_bias = data.find("logit_bias");
|
||||
|
@ -1056,7 +1056,7 @@ struct llama_server_context
|
|||
slot.has_next_token = false;
|
||||
}
|
||||
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx))
|
||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
|
||||
{
|
||||
slot.stopped_eos = true;
|
||||
slot.has_next_token = false;
|
||||
|
@ -1130,7 +1130,7 @@ struct llama_server_context
|
|||
|
||||
json get_formated_generation(llama_client_slot &slot)
|
||||
{
|
||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(ctx));
|
||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
|
||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
return json {
|
||||
|
@ -1555,11 +1555,11 @@ struct llama_server_context
|
|||
suffix_tokens.erase(suffix_tokens.begin());
|
||||
}
|
||||
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
|
||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
|
||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
|
||||
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
||||
prefix_tokens.push_back(llama_token_middle(ctx));
|
||||
prefix_tokens.push_back(llama_token_middle(model));
|
||||
prompt_tokens = prefix_tokens;
|
||||
}
|
||||
else
|
||||
|
|
|
@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
|
|||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
|
||||
// is it an end of stream?
|
||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
|
||||
LOG_TEE("\n");
|
||||
|
||||
break;
|
||||
|
|
|
@ -163,7 +163,7 @@ int main(int argc, char ** argv) {
|
|||
printf("%s", token_str.c_str());
|
||||
fflush(stdout);
|
||||
|
||||
if (id == llama_token_eos(ctx_tgt)) {
|
||||
if (id == llama_token_eos(model_tgt)) {
|
||||
has_eos = true;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue