llama : avoid hardcoded special tokens

This commit is contained in:
Georgi Gerganov 2023-08-18 17:29:20 +03:00
parent 035d511457
commit 5d2656d670
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 61 additions and 65 deletions

View file

@@ -851,7 +851,7 @@ struct sql_printer : public printer {
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
-std::vector<llama_token> tokens(n_batch, llama_token_bos());
+std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
int n_processed = 0;
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
@@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
}
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
-llama_token token = llama_token_bos();
+llama_token token = llama_token_bos(ctx);
for (int i = 0; i < n_gen; i++) {
llama_eval(ctx, &token, 1, n_past + i, n_threads);
}