llama : avoid hardcoded special tokens
This commit is contained in:
parent
035d511457
commit
5d2656d670
11 changed files with 61 additions and 65 deletions
|
@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
|
||||
|
||||
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
|
||||
const std::vector<llama_token> tmp(params.n_batch, llama_token_bos(ctx));
|
||||
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
|
||||
}
|
||||
|
||||
|
@ -345,10 +345,9 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "\n");
|
||||
|
||||
{
|
||||
auto it = params.logit_bias.find(llama_token_eos());
|
||||
auto it = params.logit_bias.find(llama_token_eos(ctx));
|
||||
if (it != params.logit_bias.end() && it->second == -INFINITY) {
|
||||
fprintf(stderr,
|
||||
"%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
|
||||
fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -398,7 +397,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// do one empty run to warm up the model
|
||||
{
|
||||
const std::vector<llama_token> tmp = { llama_token_bos(), };
|
||||
const std::vector<llama_token> tmp = { llama_token_bos(ctx), };
|
||||
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
|
||||
llama_reset_timings(ctx);
|
||||
}
|
||||
|
@ -582,7 +581,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// Apply penalties
|
||||
float nl_logit = logits[llama_token_nl()];
|
||||
float nl_logit = logits[llama_token_nl(ctx)];
|
||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
||||
llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
|
@ -591,7 +590,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, alpha_frequency, alpha_presence);
|
||||
if (!penalize_nl) {
|
||||
logits[llama_token_nl()] = nl_logit;
|
||||
logits[llama_token_nl(ctx)] = nl_logit;
|
||||
}
|
||||
|
||||
if (grammar != NULL) {
|
||||
|
@ -697,7 +696,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// deal with end of text token in interactive mode
|
||||
if (last_n_tokens.back() == llama_token_eos()) {
|
||||
if (last_n_tokens.back() == llama_token_eos(ctx)) {
|
||||
if (params.interactive) {
|
||||
if (params.antiprompt.size() != 0) {
|
||||
// tokenize and inject first reverse prompt
|
||||
|
@ -721,7 +720,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
if (params.input_prefix_bos) {
|
||||
embd_inp.push_back(llama_token_bos());
|
||||
embd_inp.push_back(llama_token_bos(ctx));
|
||||
}
|
||||
|
||||
std::string buffer;
|
||||
|
@ -786,7 +785,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// end of text token
|
||||
if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
|
||||
if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
|
||||
fprintf(stderr, " [end of text]\n");
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue