Restore signature of llama_init_from_gpt_params

This commit is contained in:
Bach Le 2023-07-07 22:25:00 +08:00
parent 478630019b
commit 8ba5b137c8
8 changed files with 20 additions and 13 deletions

View file

@ -556,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params) {
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@ -572,17 +572,23 @@ std::tuple<struct llama_model *, struct llama_context *, struct llama_context_pa
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
return lparams;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
auto lparams = llama_get_context_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr, lparams);
return std::make_tuple(nullptr, nullptr);
}
llama_context * lctx = llama_new_context_with_model(model, lparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return std::make_tuple(nullptr, nullptr, lparams);
return std::make_tuple(nullptr, nullptr);
}
if (!params.lora_adapter.empty()) {
@ -594,11 +600,11 @@ std::tuple<struct llama_model *, struct llama_context *, struct llama_context_pa
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr, lparams);
return std::make_tuple(nullptr, nullptr);
}
}
return std::make_tuple(model, lctx, lparams);
return std::make_tuple(model, lctx);
}
void console_init(console_state & con_st) {

View file

@ -104,7 +104,8 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
// Model utils
//
std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params);
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params);
//
// Console utils

View file

@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
g_ctx = &ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return nullptr;

View file

@ -41,7 +41,7 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model
std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;

View file

@ -124,12 +124,12 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * guidance_ctx = NULL;
struct llama_context_params lparams;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (params.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
guidance_ctx = llama_new_context_with_model(model, lparams);
}

View file

@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
llama_context * ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;

View file

@ -245,7 +245,7 @@ struct llama_server_context
bool loadModel(const gpt_params &params_)
{
params = params_;
std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr)
{
LOG_ERROR("unable to load model", {{"model", params_.model}});

View file

@ -71,7 +71,7 @@ int main(int argc, char ** argv)
llama_model * model;
llama_context * ctx;
std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params );
std::tie(model, ctx) = llama_init_from_gpt_params( params );
if ( model == NULL )
{