llama.cpp : split llama_context_params into model and context params (#3301)

* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build

* fix freq_base/scale default to model value

* llama-bench : keep the same model between tests when possible

* move n_threads to llama_context_params, add n_threads_batch

* fix mpi build

* remove kv_size(), cuda scratch fixes

* remove low-vram option

* add n_threads_batch to system info, refactor to get_system_info()

* add documentation about --threads-batch to the READMEs

* llama-bench fix

* main : fix rope freq/scale warning

* llama.cpp : add llama_get_model
common : add llama_tokenize from model

* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used
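
For orientation, a minimal sketch of how the split API looks after this change. This is illustrative, not code from the commit itself: the function and field names (llama_model_default_params, llama_context_default_params, llama_load_model_from_file, llama_new_context_with_model, n_threads, n_threads_batch) follow the commit message and the diff below, and llama.h is assumed to be included.

// sketch: model settings and context settings are now configured separately
llama_model_params mparams = llama_model_default_params();
mparams.n_gpu_layers = 32;                       // model-level: GPU offload
llama_model * model = llama_load_model_from_file("model.gguf", mparams);

llama_context_params cparams = llama_context_default_params();
cparams.n_ctx           = 2048;                  // context-level: KV cache size
cparams.n_threads       = 8;                     // threads for generation
cparams.n_threads_batch = 8;                     // threads for batch/prompt eval (new field)
llama_context * ctx = llama_new_context_with_model(model, cparams);

// ... use ctx ...

llama_free(ctx);
llama_free_model(model);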
Author: slaren (committed by GitHub)
Date:   2023-09-28 21:42:38 +02:00
Commit: 16bc66d947 (parent 0512d66670)
27 changed files with 713 additions and 633 deletions

examples/perplexity/perplexity.cpp

@@ -150,16 +150,18 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
-const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
-if (int(tokens.size()) < 2*params.n_ctx) {
-fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
-params.n_ctx);
+const int n_ctx = llama_n_ctx(ctx);
+if (int(tokens.size()) < 2*n_ctx) {
+fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
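
The hunk above shows the pattern repeated throughout the diff: model-level properties are now queried through the model, obtained from an existing context via the new llama_get_model accessor, while context-level values such as the context size come from llama_n_ctx(ctx) rather than params.n_ctx. As a compact sketch of the new call pattern:

const llama_model * model = llama_get_model(ctx);                     // new accessor
const bool is_spm  = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; // model property
const int  n_vocab = llama_n_vocab(model);                            // model property
const int  n_ctx   = llama_n_ctx(ctx);                                // context property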
@@ -175,20 +177,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}
-const int calc_chunk = params.n_ctx;
+const int calc_chunk = n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
-tokens.size(), params.n_ctx, params.ppl_stride);
+tokens.size(), n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-const int n_vocab = llama_n_vocab(ctx);
+const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
@@ -215,7 +217,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
-if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
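
The llama_decode change in this hunk also recurs throughout the file: the trailing n_threads argument is gone, since thread counts are now fixed once at context creation (n_threads for generation, n_threads_batch for batch/prompt processing, per the commit message). A hedged before/after sketch, assuming ctx was created as in the example near the top:

// before this commit: threads were passed on every call
//   llama_decode(ctx, batch, params.n_threads);

// after: threads live in llama_context_params, so the call site shrinks to
if (llama_decode(ctx, batch)) {
    fprintf(stderr, "llama_decode failed\n");
}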
@@ -250,7 +252,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
}
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
-for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
+for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
@@ -287,8 +289,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
-const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
+const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@@ -298,9 +301,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
-if (int(tokens.size()) < 2*params.n_ctx) {
-fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
-params.n_ctx);
+if (int(tokens.size()) < 2*n_ctx) {
+fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
+n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@@ -311,10 +314,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> prob_history;
prob_history.resize(tokens.size());
-const int n_chunk_max = tokens.size() / params.n_ctx;
+const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-const int n_vocab = llama_n_vocab(ctx);
+const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
@@ -326,10 +329,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
for (int i = 0; i < n_chunk; ++i) {
-const int start = i * params.n_ctx;
-const int end = start + params.n_ctx;
+const int start = i * n_ctx;
+const int end = start + n_ctx;
-const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
+const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
@@ -350,7 +353,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
tokens[batch_start] = llama_token_bos(ctx);
}
-if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
+if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@@ -358,7 +361,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
-const auto batch_logits = llama_get_logits(ctx);
+const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
@@ -387,10 +390,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
-const int first = params.n_ctx/2;
-process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+const int first = n_ctx/2;
+process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-count += params.n_ctx - first - 1;
+count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
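
To make the windowing arithmetic above concrete, a worked example assuming n_ctx = 512, matching the comment in the code:

// with n_ctx = 512:
const int first = n_ctx/2;          // first = 256
// logits for tokens first .. n_ctx-2 are scored against their successors,
// i.e. n_ctx - 1 - first = 512 - 1 - 256 = 255 tokens per chunk
count += n_ctx - first - 1;         // count += 255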
@@ -399,7 +402,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
-printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
+printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
}
@@ -420,7 +423,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
static std::vector<float> hellaswag_evaluate_tokens(
-llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
+llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);
@@ -428,7 +431,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
-if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
+if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}
@@ -475,7 +478,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
-const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
// This is needed as usual for LLaMA models
@@ -530,7 +533,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
printf("\ntask\tacc_norm\n");
double acc = 0.0f;
-const int n_vocab = llama_n_vocab(ctx);
+const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+const int n_ctx = llama_n_ctx(ctx);
std::vector<std::vector<int>> ending_tokens(4);
@@ -558,7 +562,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
-if (query_size > (size_t)params.n_ctx) {
+if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
@@ -571,7 +575,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// clear the KV cache
llama_kv_cache_tokens_rm(ctx, -1, -1);
-auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@@ -608,7 +612,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_size = query_embd.size();
// Stop if query wont fit the ctx window
-if (context_size + query_size > (size_t)params.n_ctx) {
+if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
@@ -620,7 +624,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
//}
// Evaluate the query
-logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@@ -716,7 +720,7 @@ int main(int argc, char ** argv) {
return 1;
}
-const int n_ctx_train = llama_n_ctx_train(ctx);
+const int n_ctx_train = llama_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
@@ -725,8 +729,7 @@ int main(int argc, char ** argv) {
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct results_perplexity results;
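
Judging from the call site in the last hunk, the new get_system_info helper added by this commit plausibly has roughly this shape in common; the declaration is inferred from usage and not shown in this diff:

// common.h -- declaration inferred from the call site above (assumption)
std::string get_system_info(const gpt_params & params);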