llama.cpp : split llama_context_params into model and context params

ggml-ci
slaren 2023-09-21 18:13:06 +02:00
parent 324f3403d5
commit cf1f80596d
16 changed files with 370 additions and 337 deletions
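
For orientation, here is a minimal sketch of the calling pattern after this split, assembled from the calls visible in the diff below. Model-level settings (GPU offload, mmap/mlock, vocab_only) now go through llama_model_params, while per-context settings (n_ctx, n_batch, seed, RoPE, KV cache type) go through llama_context_params. The model path and the specific field values are placeholders, not part of the change:

#include "llama.h"

int main() {
    llama_backend_init(/*numa =*/ false);

    // model-level parameters: how the weights are loaded
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 32;      // placeholder value
    mparams.use_mmap     = true;

    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    // per-context parameters: how this particular context runs inference
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx   = 0;            // 0 = use the model's training context size (n_ctx_train)
    cparams.n_batch = 512;
    cparams.seed    = 1234;         // placeholder value

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, eval, sample ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}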


@ -722,40 +722,49 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
// Model utils
//
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
if (params.n_gpu_layers != -1) {
lparams.n_gpu_layers = params.n_gpu_layers;
mparams.n_gpu_layers = params.n_gpu_layers;
}
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.mul_mat_q = params.mul_mat_q;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
mparams.main_gpu = params.main_gpu;
mparams.tensor_split = params.tensor_split;
mparams.low_vram = params.low_vram;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
return lparams;
return mparams;
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.low_vram = params.low_vram;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.perplexity;
cparams.embedding = params.embedding;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
return cparams;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto lparams = llama_context_params_from_gpt_params(params);
auto mparams = llama_model_params_from_gpt_params(params);
auto cparams = llama_context_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}
llama_context * lctx = llama_new_context_with_model(model, lparams);
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);


@ -131,6 +131,7 @@ std::string gpt_random_prompt(std::mt19937 & rng);
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//


@ -43,9 +43,11 @@ int main(int argc, char ** argv) {
}
const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
const int n_ctx = llama_n_ctx(ctx);
if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
__func__, n_ctx_train, n_ctx);
}
// print system information
@ -70,9 +72,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
if (embd_inp.size() > (size_t)params.n_ctx) {
if (embd_inp.size() > (size_t)n_ctx) {
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
__func__, embd_inp.size(), params.n_ctx);
__func__, embd_inp.size(), n_ctx);
return 1;
}


@ -356,18 +356,27 @@ struct cmd_params_instance {
bool low_vram;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
llama_context_params to_llama_params() const {
llama_context_params lparams = llama_context_default_params();
lparams.n_ctx = n_prompt + n_gen;
lparams.n_batch = n_batch;
lparams.f16_kv = !f32_kv;
lparams.n_gpu_layers = n_gpu_layers;
lparams.main_gpu = main_gpu;
lparams.mul_mat_q = mul_mat_q;
lparams.low_vram = low_vram;
lparams.tensor_split = tensor_split.data();
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
return lparams;
mparams.n_gpu_layers = n_gpu_layers;
mparams.main_gpu = main_gpu;
mparams.low_vram = low_vram;
mparams.tensor_split = tensor_split.data();
return mparams;
}
llama_context_params to_llama_cparams() const {
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
cparams.f16_kv = !f32_kv;
cparams.mul_mat_q = mul_mat_q;
cparams.low_vram = low_vram;
return cparams;
}
};
@ -960,15 +969,13 @@ int main(int argc, char ** argv) {
for (const auto & inst : params_instances) {
// TODO: keep the model between tests when possible
llama_context_params lparams = inst.to_llama_params();
llama_model * lmodel = llama_load_model_from_file(inst.model.c_str(), lparams);
llama_model * lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams());
if (lmodel == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
return 1;
}
llama_context * ctx = llama_new_context_with_model(lmodel, lparams);
llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
llama_free_model(lmodel);


@ -140,6 +140,11 @@ int main(int argc, char ** argv) {
return 0;
}
if (params.n_ctx != 0 && params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
if (params.rope_freq_base != 10000.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}
@ -185,12 +190,12 @@ int main(int argc, char ** argv) {
}
const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
if (n_ctx > n_ctx_train) {
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
} else if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
__func__, n_ctx_train, n_ctx);
}
// print system information
@ -220,7 +225,7 @@ int main(int argc, char ** argv) {
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(params.n_ctx);
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
@ -276,9 +281,6 @@ int main(int argc, char ** argv) {
LOG("guidance_offset: %s", log_tostr(guidance_offset));
}
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;


@ -145,9 +145,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
const int n_ctx = llama_n_ctx(ctx);
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@ -163,13 +165,13 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}
const int calc_chunk = params.n_ctx;
const int calc_chunk = n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), params.n_ctx, params.ppl_stride);
tokens.size(), n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
@ -235,7 +237,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
}
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
@ -274,6 +276,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -283,9 +286,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@ -296,7 +299,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> prob_history;
prob_history.resize(tokens.size());
const int n_chunk_max = tokens.size() / params.n_ctx;
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(ctx);
@ -311,10 +314,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.n_ctx;
const int end = start + params.n_ctx;
const int start = i * n_ctx;
const int end = start + n_ctx;
const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
@ -369,10 +372,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = params.n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
const int first = n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += params.n_ctx - first - 1;
count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
@ -381,7 +384,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
}
@ -513,6 +516,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
double acc = 0.0f;
const int n_vocab = llama_n_vocab(ctx);
const int n_ctx = llama_n_ctx(ctx);
std::vector<std::vector<int>> ending_tokens(4);
@ -540,7 +544,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (query_size > (size_t)params.n_ctx) {
if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
@ -587,7 +591,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (context_size + query_size > (size_t)params.n_ctx) {
if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}


@ -309,21 +309,22 @@ int main(int argc, char ** argv) {
llama_context * ctx;
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
mparams.use_mlock = false;
lparams.n_ctx = 256;
lparams.seed = 1;
lparams.f16_kv = false;
lparams.use_mlock = false;
model = llama_load_model_from_file(params.model.c_str(), lparams);
model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
cparams.n_ctx = 256;
cparams.seed = 1;
cparams.f16_kv = false;
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());


@ -23,23 +23,17 @@ int main(int argc, char ** argv) {
params.n_predict = 16;
}
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
auto n_past = 0;
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
auto model = llama_load_model_from_file(params.model.c_str(), lparams);
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params( params );
if (model == nullptr) {
return 1;
}
auto ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
@ -106,7 +100,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
// make new context
auto ctx2 = llama_new_context_with_model(model, lparams);
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
// Load state (rng, logits, embedding and kv_cache) from file
{


@ -200,6 +200,7 @@ struct llama_server_context
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
@ -239,7 +240,7 @@ struct llama_server_context
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
@ -265,8 +266,8 @@ struct llama_server_context
LOG_ERROR("unable to load model", {{"model", params_.model}});
return false;
}
last_n_tokens.resize(params.n_ctx);
n_ctx = llama_n_ctx(ctx);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
@ -351,19 +352,19 @@ struct llama_server_context
{
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx)
if (num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (params.n_ctx - params.n_keep) / 2;
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@ -409,10 +410,10 @@ struct llama_server_context
completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)params.n_ctx)
if (embd.size() >= (size_t)n_ctx)
{
// Reset context
const int n_left = (params.n_ctx - params.n_keep) / 2;
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
@ -420,7 +421,7 @@ struct llama_server_context
n_past = params.n_keep;
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@ -461,7 +462,7 @@ struct llama_server_context
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
@ -492,7 +493,7 @@ struct llama_server_context
// Apply penalties
float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
@ -1002,7 +1003,7 @@ static json format_generation_settings(llama_server_context &llama)
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.params.n_ctx},
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
{"temp", llama.params.temp},


@ -30,16 +30,14 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa);
llama_context_params ctx_params = llama_context_default_params();
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), llama_model_default_params());
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_context * ctx = llama_new_context_with_model(model, llama_context_default_params());
// tokenize the prompt


@ -2001,11 +2001,13 @@ int main(int argc, char ** argv) {
printf("%s: seed: %u\n", __func__, params.seed);
srand(params.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
struct llama_model_params mparams = llama_model_default_params();
mparams.vocab_only = true;
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct llama_context_params cparams = llama_context_default_params();
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams);
struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams);
printf("%s: tokenize training data\n", __func__);
std::vector<llama_token> train_tokens;

llama.cpp

@ -930,9 +930,9 @@ static const size_t MB = kB*kB;
static const size_t GB = kB*kB*kB;
struct llama_hparams {
bool vocab_only;
uint32_t n_vocab;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_ctx; // context size used during inference
uint32_t n_embd;
uint32_t n_head;
uint32_t n_head_kv;
@ -943,8 +943,8 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;
float rope_freq_base;
float rope_freq_scale;
float rope_freq_base_train;
float rope_freq_scale_train;
bool operator!=(const llama_hparams & other) const {
return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
@ -961,15 +961,16 @@ struct llama_hparams {
uint32_t n_embd_gqa() const {
return n_embd/n_gqa();
}
};
size_t kv_size() const {
size_t result = 2ull;
result *= (size_t) n_embd_gqa();
result *= (size_t) n_ctx;
result *= (size_t) n_layer;
result *= sizeof(ggml_fp16_t);
return result;
}
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
float rope_freq_base;
float rope_freq_scale;
bool mul_mat_q;
};
struct llama_layer {
@ -1127,11 +1128,8 @@ struct llama_model {
};
struct llama_context {
llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
~llama_context() {
if (model_owner) {
delete &model;
}
#ifdef GGML_USE_METAL
if (ctx_metal) {
ggml_metal_free(ctx_metal);
@ -1142,27 +1140,35 @@ struct llama_context {
}
}
llama_cparams cparams;
const llama_model & model;
// key + value cache for the self attention
struct llama_kv_cache kv_self;
size_t kv_size() const {
size_t result = 2ull;
result *= (size_t) model.hparams.n_embd_gqa();
result *= (size_t) cparams.n_ctx;
result *= (size_t) model.hparams.n_layer;
result *= sizeof(ggml_fp16_t);
return result;
}
std::mt19937 rng;
bool has_evaluated_once = false;
int64_t t_start_us;
int64_t t_load_us;
int64_t t_sample_us = 0;
int64_t t_eval_us = 0;
int64_t t_p_eval_us = 0;
int64_t t_eval_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_eval = 0; // number of eval calls
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
const llama_model & model;
bool model_owner = false;
int64_t t_load_us;
int64_t t_start_us;
// key + value cache for the self attention
struct llama_kv_cache kv_self;
int32_t n_eval = 0; // number of eval calls
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@ -1592,7 +1598,15 @@ struct llama_model_loader {
// load LLaMA models
//
static std::string llama_model_ftype_name(enum llama_ftype ftype) {
static std::string llama_model_arch_name(llm_arch arch) {
auto it = LLM_ARCH_NAMES.find(arch);
if (it == LLM_ARCH_NAMES.end()) {
return "unknown";
}
return it->second;
}
static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
@ -1648,10 +1662,7 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
static void llm_load_hparams(
llama_model_loader & ml,
llama_model & model,
int n_ctx,
float rope_freq_base,
float rope_freq_scale) {
llama_model & model) {
struct gguf_context * ctx = ml.ctx_gguf;
const auto kv = LLM_KV(model.arch);
@ -1662,29 +1673,25 @@ static void llm_load_hparams(
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
// get hparams kv
GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
// rope_freq_base (optional)
if (rope_freq_base == 0.0f) {
rope_freq_base = 10000.0f;
GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
}
hparams.rope_freq_base_train = 10000.0f;
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
// rope_freq_scale (inverse of the kv) is optional
if (rope_freq_scale == 0.0f) {
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
rope_freq_scale = 1.0f/ropescale;
}
float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
hparams.rope_freq_scale_train = 1.0f/ropescale;
// sanity check for n_rot (optional)
{
@ -1751,10 +1758,6 @@ static void llm_load_hparams(
};
model.ftype = ml.ftype;
hparams.n_ctx = n_ctx;
hparams.rope_freq_base = rope_freq_base;
hparams.rope_freq_scale = rope_freq_scale;
}
// TODO: This should probably be in llama.h
@ -1880,31 +1883,30 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
const auto & vocab = model.vocab;
// hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head);
LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
if (ml.n_bytes < GB) {
LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
} else {
LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
}
// general kv
@ -1922,13 +1924,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
static void llm_load_tensors(
llama_model_loader & ml,
llama_model & model,
int n_batch,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
const bool mul_mat_q,
bool low_vram,
ggml_type memory_type,
bool use_mlock,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
@ -1967,11 +1966,9 @@ static void llm_load_tensors(
}
(void) main_gpu;
(void) mul_mat_q;
#if defined(GGML_USE_CUBLAS)
LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
ggml_cuda_set_main_device(main_gpu);
ggml_cuda_set_mul_mat_q(mul_mat_q);
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
#elif defined(GGML_USE_CLBLAST)
@ -2293,20 +2290,12 @@ static void llm_load_tensors(
// print memory requirements
{
const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
// this is the total memory required to run the inference
size_t mem_required =
ctx_size +
mmapped_size - vram_weights; // weights in VRAM not in memory
// this is the memory required by one llama_state
const size_t mem_required_state = scale*hparams.kv_size();
LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
(void) n_batch;
LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@ -2315,36 +2304,17 @@ static void llm_load_tensors(
if (n_gpu_layers > (int) hparams.n_layer) {
LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__);
}
size_t vram_kv_cache = 0;
#ifdef GGML_USE_CUBLAS
const int max_backend_supported_layers = hparams.n_layer + 3;
const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3;
if (n_gpu_layers > (int) hparams.n_layer + 1) {
if (low_vram) {
LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
} else {
LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
vram_kv_cache += hparams.kv_size() / 2;
}
}
if (n_gpu_layers > (int) hparams.n_layer + 2) {
if (low_vram) {
LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
} else {
LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
vram_kv_cache += hparams.kv_size() / 2;
}
}
#elif defined(GGML_USE_CLBLAST)
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
#endif // GGML_USE_CUBLAS
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
__func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
__func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0);
#else
(void) n_gpu_layers;
#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@ -2379,29 +2349,25 @@ static void llm_load_tensors(
static bool llama_model_load(
const std::string & fname,
llama_model & model,
int n_ctx,
int n_batch,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
const bool mul_mat_q,
float rope_freq_base,
float rope_freq_scale,
bool low_vram,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
bool vocab_only,
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap));
llama_model_loader ml(fname, use_mmap);
llm_load_arch (*ml, model);
llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale);
llm_load_vocab (*ml, model);
model.hparams.vocab_only = vocab_only;
llm_load_print_meta(*ml, model);
llm_load_arch (ml, model);
llm_load_hparams(ml, model);
llm_load_vocab (ml, model);
llm_load_print_meta(ml, model);
if (model.hparams.n_vocab != model.vocab.id_to_token.size()) {
throw std::runtime_error("vocab size mismatch");
@ -2413,8 +2379,8 @@ static bool llama_model_load(
}
llm_load_tensors(
*ml, model, n_batch, n_gpu_layers,
main_gpu, tensor_split, mul_mat_q, low_vram, memory_type,
ml, model, n_gpu_layers,
main_gpu, tensor_split, low_vram,
use_mlock, progress_callback, progress_callback_user_data);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
@ -2437,6 +2403,7 @@ static struct ggml_cgraph * llm_build_llama(
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
@ -2444,7 +2411,7 @@ static struct ggml_cgraph * llm_build_llama(
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
@ -2452,8 +2419,8 @@ static struct ggml_cgraph * llm_build_llama(
GGML_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float freq_base = cparams.rope_freq_base;
const float freq_scale = cparams.rope_freq_scale;
const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -2769,6 +2736,7 @@ static struct ggml_cgraph * llm_build_baichaun(
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
@ -2776,7 +2744,7 @@ static struct ggml_cgraph * llm_build_baichaun(
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
@ -2784,8 +2752,8 @@ static struct ggml_cgraph * llm_build_baichaun(
GGML_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float freq_base = cparams.rope_freq_base;
const float freq_scale = cparams.rope_freq_scale;
const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -3129,6 +3097,7 @@ static struct ggml_cgraph * llm_build_falcon(
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
@ -3136,7 +3105,7 @@ static struct ggml_cgraph * llm_build_falcon(
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
@ -3144,8 +3113,8 @@ static struct ggml_cgraph * llm_build_falcon(
GGML_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
const float freq_base = cparams.rope_freq_base;
const float freq_scale = cparams.rope_freq_scale;
const float norm_eps = hparams.f_norm_eps;
const int n_gpu_layers = model.n_gpu_layers;
@ -3436,6 +3405,7 @@ static struct ggml_cgraph * llm_build_starcoder(
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
@ -3443,7 +3413,7 @@ static struct ggml_cgraph * llm_build_starcoder(
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
@ -3658,6 +3628,8 @@ static struct ggml_cgraph * llama_build_graph(
const float * embd,
int n_tokens,
int n_past) {
GGML_ASSERT(n_tokens > 0);
const auto & model = lctx.model;
struct ggml_cgraph * result = NULL;
@ -3706,11 +3678,14 @@ static bool llama_eval_internal(
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
const auto & cparams = lctx.cparams;
const int n_ctx = cparams.n_ctx;
const int n_batch = cparams.n_batch;
GGML_ASSERT(n_tokens > 0);
GGML_ASSERT(n_past >= 0);
// TODO: keep the values of n_batch and n_ctx
// GGML_ASSERT(n_tokens <= n_batch);
// GGML_ASSERT(n_past + n_tokens <= n_ctx);
GGML_ASSERT(n_tokens <= n_batch);
GGML_ASSERT(n_past + n_tokens <= n_ctx);
const int64_t t_start_us = ggml_time_us();
@ -3752,6 +3727,8 @@ static bool llama_eval_internal(
ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
}
}
ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
#endif
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@ -5663,11 +5640,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
nthread = std::thread::hardware_concurrency();
}
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
llama_model_loader ml(fname_inp, /*use_mmap*/ false);
llama_model model;
llm_load_arch(*ml, model);
llm_load_hparams(*ml, model, 0, 0, 0);
llm_load_arch(ml, model);
llm_load_hparams(ml, model);
if (params->only_copy) {
ftype = model.ftype;
@ -5677,7 +5654,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
struct gguf_context * ctx_out = gguf_init_empty();
// copy the KV pairs from the input file
gguf_set_kv (ctx_out, ml->ctx_gguf);
gguf_set_kv (ctx_out, ml.ctx_gguf);
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out, "general.file_type", ftype);
@ -5685,8 +5662,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
int n_attention_wv = 0;
int n_feed_forward_w2 = 0;
for (int i = 0; i < ml->n_tensors; ++i) {
struct ggml_tensor * meta = ml->get_tensor_meta(i);
for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * meta = ml.get_tensor_meta(i);
const std::string name = ggml_get_name(meta);
@ -5722,8 +5699,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
std::vector<no_init<float>> f32_conv_buf;
// populate the original tensors so we get an initial meta data
for (int i = 0; i < ml->n_tensors; ++i) {
struct ggml_tensor * meta = ml->get_tensor_meta(i);
for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * meta = ml.get_tensor_meta(i);
gguf_add_tensor(ctx_out, meta);
}
@ -5736,8 +5713,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// placeholder for the meta data
::zeros(fout, meta_size);
for (int i = 0; i < ml->n_tensors; ++i) {
struct ggml_tensor * tensor = ml->get_tensor_meta(i);
for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * tensor = ml.get_tensor_meta(i);
const std::string name = ggml_get_name(tensor);
@ -5745,10 +5722,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
read_data.resize(ggml_nbytes(tensor));
}
tensor->data = read_data.data();
ml->load_data_for(tensor);
ml.load_data_for(tensor);
LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
++idx, ml->n_tensors,
++idx, ml.n_tensors,
ggml_get_name(tensor),
llama_format_tensor_shape(tensor).c_str(),
ggml_type_name(tensor->type));
@ -6176,33 +6153,40 @@ static int llama_apply_lora_from_file_internal(
//
// interface implementation
//
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.low_vram =*/ false,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
};
#ifdef GGML_USE_METAL
result.n_gpu_layers = 1;
#endif
return result;
}
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.low_vram =*/ false,
/*.mul_mat_q =*/ true,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
/*.embedding =*/ false,
};
#ifdef GGML_USE_METAL
result.n_gpu_layers = 1;
#endif
return result;
}
@ -6261,13 +6245,11 @@ int64_t llama_time_us(void) {
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_context_params params) {
struct llama_model_params params) {
ggml_time_init();
llama_model * model = new llama_model;
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
unsigned cur_percentage = 0;
if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;
@ -6284,9 +6266,9 @@ struct llama_model * llama_load_model_from_file(
};
}
if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers,
params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,
params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
if (!llama_model_load(path_model, *model, params.n_gpu_layers,
params.main_gpu, params.tensor_split, params.low_vram,
params.use_mmap, params.use_mlock, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
delete model;
@ -6310,18 +6292,31 @@ struct llama_context * llama_new_context_with_model(
llama_context * ctx = new llama_context(*model);
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
// reserve memory for context buffers
if (!params.vocab_only) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) {
if (!hparams.vocab_only) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
@ -6332,11 +6327,33 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
const auto & hparams = ctx->model.hparams;
#ifdef GGML_USE_CUBLAS
{
size_t vram_kv_cache = 0;
if (model->n_gpu_layers > (int) hparams.n_layer + 1) {
if (params.low_vram) {
LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
} else {
LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
vram_kv_cache += ctx->kv_size() / 2;
}
}
if (model->n_gpu_layers > (int) hparams.n_layer + 2) {
if (params.low_vram) {
LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
} else {
LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
vram_kv_cache += ctx->kv_size() / 2;
}
}
LLAMA_LOG_INFO("%s: VRAM kv cache = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
}
#endif
// resized during inference
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
} else {
ctx->logits.reserve(hparams.n_vocab);
}
@ -6354,8 +6371,8 @@ struct llama_context * llama_new_context_with_model(
ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
// build worst-case graph
int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
int n_past = hparams.n_ctx - n_tokens;
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
#ifdef GGML_USE_METAL
@ -6373,7 +6390,7 @@ struct llama_context * llama_new_context_with_model(
// measure memory requirements for the graph
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
// recreate allocator with exact memory requirements
ggml_allocr_free(ctx->alloc);
@ -6386,6 +6403,7 @@ struct llama_context * llama_new_context_with_model(
}
#endif
#ifdef GGML_USE_CUBLAS
// TODO: different scratch buffers per context
if (params.low_vram) {
LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
ggml_cuda_set_scratch_size(0); // disable scratch
@ -6422,11 +6440,8 @@ struct llama_context * llama_new_context_with_model(
return NULL; \
}
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0));
#undef LLAMA_METAL_CHECK_BUF
}
@ -6438,7 +6453,7 @@ struct llama_context * llama_new_context_with_model(
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process
const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx));
const std::vector<llama_token> tmp(cparams.n_ctx, llama_token_bos(ctx));
while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
llama_backend_free();
exit(1);
@ -6448,20 +6463,6 @@ struct llama_context * llama_new_context_with_model(
return ctx;
}
static struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params) {
struct llama_model * model = llama_load_model_from_file(path_model, params);
if (!model) {
return nullptr;
}
struct llama_context * ctx = llama_new_context_with_model(model, params);
ctx->model_owner = true;
return ctx;
}
void llama_free(struct llama_context * ctx) {
delete ctx;
}
@ -6471,7 +6472,7 @@ int llama_n_vocab(const struct llama_context * ctx) {
}
int llama_n_ctx(const struct llama_context * ctx) {
return llama_model_n_ctx(&ctx->model);
return ctx->cparams.n_ctx;
}
int llama_n_ctx_train(const struct llama_context * ctx) {
@ -6490,10 +6491,6 @@ int llama_model_n_vocab(const struct llama_model * model) {
return model->vocab.id_to_token.size();
}
int llama_model_n_ctx(const struct llama_model * model) {
return model->hparams.n_ctx;
}
int llama_model_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}
@ -6504,7 +6501,7 @@ int llama_model_n_embd(const struct llama_model * model) {
int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s",
model->name.c_str(),
llama_model_arch_name(model->arch).c_str(),
llama_model_type_name(model->type),
llama_model_ftype_name(model->ftype).c_str());
}
@ -6704,9 +6701,11 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
{
const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const int n_layer = hparams.n_layer;
const int n_embd = hparams.n_embd_gqa();
const int n_ctx = hparams.n_ctx;
const int n_ctx = cparams.n_ctx;
const size_t kv_size = kv_self.buf.size;
const int kv_ntok = llama_get_kv_cache_token_count(ctx);
@ -6812,9 +6811,11 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
{
const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const int n_layer = hparams.n_layer;
const int n_embd = hparams.n_embd_gqa();
const int n_ctx = hparams.n_ctx;
const int n_ctx = cparams.n_ctx;
size_t kv_size;
int kv_ntok;

llama.h

@ -122,19 +122,11 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
int32_t n_ctx; // text context
int32_t n_batch; // prompt processing batch size
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
@ -142,12 +134,25 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value.
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
};
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context
uint32_t n_batch; // prompt processing batch size
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
// Keep the booleans together to avoid misalignment during copy-by-value.
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
};
@ -215,6 +220,7 @@ extern "C" {
int32_t n_eval;
};
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
@ -228,7 +234,7 @@ extern "C" {
LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_context_params params);
struct llama_model_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
@ -253,7 +259,6 @@ extern "C" {
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int llama_model_n_embd (const struct llama_model * model);


@ -62,18 +62,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());


@ -64,18 +64,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());


@ -52,18 +52,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());