Update public API use cases: move away from deprecated llama_init_from_file

This commit is contained in:
Didzis Gosko 2023-06-20 23:47:33 +03:00
parent 7bba46ba62
commit 69f776282b
11 changed files with 100 additions and 27 deletions

View file

@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res; return res;
} }
struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params(); auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx; lparams.n_ctx = params.n_ctx;
@ -552,11 +552,17 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.logits_all = params.perplexity; lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding; lparams.embedding = params.embedding;
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return NULL; return std::make_tuple(nullptr, nullptr);
}
llama_context * lctx = llama_new_context_with_model(model, lparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
} }
if (!params.lora_adapter.empty()) { if (!params.lora_adapter.empty()) {
@ -566,11 +572,13 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
params.n_threads); params.n_threads);
if (err != 0) { if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return NULL; llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
} }
} }
return lctx; return std::make_tuple(model, lctx);
} }
void console_init(console_state & con_st) { void console_init(console_state & con_st) {

View file

@ -9,6 +9,7 @@
#include <random> #include <random>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <tuple>
#if !defined (_WIN32) #if !defined (_WIN32)
#include <stdio.h> #include <stdio.h>
@ -95,7 +96,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
// Model utils // Model utils
// //
struct llama_context * llama_init_from_gpt_params(const gpt_params & params); std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
// //
// Console utils // Console utils

View file

@ -37,11 +37,12 @@ int main(int argc, char ** argv) {
llama_init_backend(); llama_init_backend();
llama_model * model;
llama_context * ctx; llama_context * ctx;
// load the model // load the model
ctx = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (ctx == NULL) { if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__); fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1; return 1;
} }
@ -90,6 +91,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 0; return 0;
} }

View file

@ -107,12 +107,13 @@ int main(int argc, char ** argv) {
llama_init_backend(); llama_init_backend();
llama_model * model;
llama_context * ctx; llama_context * ctx;
g_ctx = &ctx; g_ctx = &ctx;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (ctx == NULL) { if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__); fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1; return 1;
} }
@ -139,6 +140,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 0; return 0;
} }
@ -147,6 +149,7 @@ int main(int argc, char ** argv) {
if (params.export_cgraph) { if (params.export_cgraph) {
llama_eval_export(ctx, "llama.ggml"); llama_eval_export(ctx, "llama.ggml");
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 0; return 0;
} }
@ -666,6 +669,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 0; return 0;
} }

View file

@ -149,11 +149,12 @@ int main(int argc, char ** argv) {
llama_init_backend(); llama_init_backend();
llama_model * model;
llama_context * ctx; llama_context * ctx;
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
ctx = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (ctx == NULL) { if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__); fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1; return 1;
} }
@ -169,6 +170,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 0; return 0;
} }

View file

@ -320,6 +320,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Loading model\n"); fprintf(stderr, "Loading model\n");
const int64_t t_main_start_us = ggml_time_us(); const int64_t t_main_start_us = ggml_time_us();
llama_model * model;
llama_context * ctx; llama_context * ctx;
{ {
@ -330,10 +331,18 @@ int main(int argc, char ** argv) {
lparams.f16_kv = false; lparams.f16_kv = false;
lparams.use_mlock = false; lparams.use_mlock = false;
ctx = llama_init_from_file(params.model.c_str(), lparams); model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return 1; return 1;
} }
} }
@ -357,6 +366,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: error: Quantization should be tested with a float model, " fprintf(stderr, "%s: error: Quantization should be tested with a float model, "
"this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type);
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
return 1; return 1;
} }
included_layers++; included_layers++;
@ -415,6 +425,7 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model);
// report timing // report timing
{ {
const int64_t t_main_end_us = ggml_time_us(); const int64_t t_main_end_us = ggml_time_us();

View file

@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init // init
auto ctx = llama_init_from_file(params.model.c_str(), lparams); auto model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == nullptr) {
return 1;
}
auto ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
}
auto tokens = std::vector<llama_token>(params.n_ctx); auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) { if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
llama_free(ctx);
llama_free_model(model);
return 1; return 1;
} }
@ -84,6 +94,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str); printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
return 1; return 1;
} }
n_past += 1; n_past += 1;
@ -91,23 +103,27 @@ int main(int argc, char ** argv) {
printf("\n\n"); printf("\n\n");
// free old model // free old context
llama_free(ctx); llama_free(ctx);
// load new model // make new context
auto ctx2 = llama_init_from_file(params.model.c_str(), lparams); auto ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file // Load state (rng, logits, embedding and kv_cache) from file
{ {
FILE *fp_read = fopen("dump_state.bin", "rb"); FILE *fp_read = fopen("dump_state.bin", "rb");
if (state_size != llama_get_state_size(ctx2)) { if (state_size != llama_get_state_size(ctx2)) {
fprintf(stderr, "\n%s : failed to validate state size\n", __func__); fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1; return 1;
} }
const size_t ret = fread(state_mem, 1, state_size, fp_read); const size_t ret = fread(state_mem, 1, state_size, fp_read);
if (ret != state_size) { if (ret != state_size) {
fprintf(stderr, "\n%s : failed to read state\n", __func__); fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1; return 1;
} }
@ -138,6 +154,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str); printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1; return 1;
} }
n_past += 1; n_past += 1;
@ -145,5 +163,8 @@ int main(int argc, char ** argv) {
printf("\n\n"); printf("\n\n");
llama_free(ctx2);
llama_free_model(model);
return 0; return 0;
} }

View file

@ -115,6 +115,7 @@ struct llama_server_context {
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens; std::vector<llama_token> last_n_tokens;
llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
gpt_params params; gpt_params params;
@ -130,6 +131,10 @@ struct llama_server_context {
llama_free(ctx); llama_free(ctx);
ctx = nullptr; ctx = nullptr;
} }
if (model) {
llama_free_model(model);
model = nullptr;
}
} }
void rewind() { void rewind() {
@ -150,8 +155,8 @@ struct llama_server_context {
bool loadModel(const gpt_params & params_) { bool loadModel(const gpt_params & params_) {
params = params_; params = params_;
ctx = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (ctx == nullptr) { if (model == nullptr) {
LOG_ERROR("unable to load model", { { "model", params_.model } }); LOG_ERROR("unable to load model", { { "model", params_.model } });
return false; return false;
} }

View file

@ -68,11 +68,12 @@ int main(int argc, char ** argv)
llama_init_backend(); llama_init_backend();
llama_model * model;
llama_context * ctx; llama_context * ctx;
ctx = llama_init_from_gpt_params( params ); std::tie(model, ctx) = llama_init_from_gpt_params( params );
if ( ctx == NULL ) if ( model == NULL )
{ {
fprintf( stderr , "%s: error: unable to load model\n" , __func__ ); fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
return 1; return 1;
@ -170,6 +171,7 @@ int main(int argc, char ** argv)
} // wend of main loop } // wend of main loop
llama_free( ctx ); llama_free( ctx );
llama_free_model( model );
return 0; return 0;
} }

View file

@ -3054,7 +3054,8 @@ int main(int argc, char ** argv) {
struct llama_context_params llama_params = llama_context_default_params(); struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true; llama_params.vocab_only = true;
struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params); struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct llama_vocab vocab; struct llama_vocab vocab;
{ {
@ -3395,6 +3396,8 @@ int main(int argc, char ** argv) {
delete[] compute_addr; delete[] compute_addr;
delete[] compute_buf_0; delete[] compute_buf_0;
delete[] compute_buf_1; delete[] compute_buf_1;
llama_free(lctx);
llama_free_model(lmodel);
ggml_free(model.ctx); ggml_free(model.ctx);
return 0; return 0;

View file

@ -28,6 +28,7 @@ int main(int argc, char **argv) {
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
llama_model * model;
llama_context * ctx; llama_context * ctx;
// load the vocab // load the vocab
@ -36,10 +37,18 @@ int main(int argc, char **argv) {
lparams.vocab_only = true; lparams.vocab_only = true;
ctx = llama_init_from_file(fname.c_str(), lparams); model = llama_load_model_from_file(fname.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
llama_free_model(model);
return 1; return 1;
} }
} }
@ -48,6 +57,8 @@ int main(int argc, char **argv) {
if (n_vocab != 32000) { if (n_vocab != 32000) {
fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab); fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
llama_free_model(model);
llama_free(ctx);
return 2; return 2;
} }
@ -77,10 +88,13 @@ int main(int argc, char **argv) {
} }
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_free_model(model);
llama_free(ctx);
return 3; return 3;
} }
} }
llama_free_model(model);
llama_free(ctx); llama_free(ctx);
return 0; return 0;