llama.cpp : split llama_context_params into model and context params (#3301)
* llama.cpp : split llama_context_params into model and context params ggml-ci * fix metal build * fix freq_base/scale default to model value * llama-bench : keep the same model between tests when possible * move n_threads to llama_context_params, add n_threads_batch * fix mpi build * remove kv_size(), cuda scratch fixes * remove low-vram option * add n_threads_batch to system info, refactor to get_system_info() * add documentation about --threads-batch to the READMEs * llama-bench fix * main : fix rope freq/scale warning * llama.cpp : add llama_get_model common : add llama_tokenize from model * remove duplicated ctx/model functions ggml-ci * cuda : print total VRAM used
This commit is contained in:
parent
0512d66670
commit
16bc66d947
27 changed files with 713 additions and 633 deletions
|
@ -23,23 +23,17 @@ int main(int argc, char ** argv) {
|
|||
params.n_predict = 16;
|
||||
}
|
||||
|
||||
auto lparams = llama_context_default_params();
|
||||
|
||||
lparams.n_ctx = params.n_ctx;
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
|
||||
auto n_past = 0;
|
||||
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
|
||||
|
||||
// init
|
||||
auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params( params );
|
||||
if (model == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
auto * ctx = llama_new_context_with_model(model, lparams);
|
||||
if (ctx == nullptr) {
|
||||
llama_free_model(model);
|
||||
return 1;
|
||||
|
@ -54,7 +48,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate prompt
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
|
||||
|
||||
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
||||
n_past += n_prompt_tokens;
|
||||
|
@ -79,7 +73,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
|
@ -91,7 +85,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
@ -106,7 +100,7 @@ int main(int argc, char ** argv) {
|
|||
llama_free(ctx);
|
||||
|
||||
// make new context
|
||||
auto * ctx2 = llama_new_context_with_model(model, lparams);
|
||||
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
// Load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
|
@ -139,7 +133,7 @@ int main(int argc, char ** argv) {
|
|||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx2);
|
||||
auto n_vocab = llama_n_vocab(ctx2);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
|
@ -151,7 +145,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx2);
|
||||
llama_free_model(model);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue