llama : avoid hardcoded special tokens
This commit is contained in:
parent
035d511457
commit
5d2656d670
11 changed files with 61 additions and 65 deletions
|
@ -851,7 +851,7 @@ struct sql_printer : public printer {
|
|||
};
|
||||
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
||||
std::vector<llama_token> tokens(n_batch, llama_token_bos());
|
||||
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
|
||||
int n_processed = 0;
|
||||
while (n_processed < n_prompt) {
|
||||
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
||||
|
@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||
}
|
||||
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||
llama_token token = llama_token_bos();
|
||||
llama_token token = llama_token_bos(ctx);
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
llama_eval(ctx, &token, 1, n_past + i, n_threads);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue