llama.cpp : split llama_context_params into model and context params (#3301)

* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build

* fix freq_base/scale default to model value

* llama-bench : keep the same model between tests when possible

* move n_threads to llama_context_params, add n_threads_batch

* fix mpi build

* remove kv_size(), cuda scratch fixes

* remove low-vram option

* add n_threads_batch to system info, refactor to get_system_info()

* add documentation about --threads-batch to the READMEs

* llama-bench fix

* main : fix rope freq/scale warning

* llama.cpp : add llama_get_model
common : add llama_tokenize from model

* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used
This commit is contained in:
slaren 2023-09-28 21:42:38 +02:00 committed by GitHub
parent 0512d66670
commit 16bc66d947
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 713 additions and 633 deletions

View file

@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(model);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
batch.logits[i] = false;
}
if (llama_decode(ctx, batch, params.n_threads) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@ -272,7 +272,7 @@ int main(int argc, char ** argv) {
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view, params.n_threads);
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size