Restore signature of llama_init_from_gpt_params

Bach Le 2023-07-07 22:25:00 +08:00
parent 478630019b
commit 8ba5b137c8
8 changed files with 20 additions and 13 deletions

@@ -556,7 +556,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params) {
+struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx = params.n_ctx;
@@ -572,17 +572,23 @@ std::tuple<struct llama_model *, struct llama_context *, struct llama_context_pa
     lparams.logits_all = params.perplexity;
     lparams.embedding = params.embedding;
 
+    return lparams;
+}
+
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+    auto lparams = llama_get_context_params_from_gpt_params(params);
+
     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return std::make_tuple(nullptr, nullptr, lparams);
+        return std::make_tuple(nullptr, nullptr);
     }
 
     llama_context * lctx = llama_new_context_with_model(model, lparams);
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_free_model(model);
-        return std::make_tuple(nullptr, nullptr, lparams);
+        return std::make_tuple(nullptr, nullptr);
     }
 
     if (!params.lora_adapter.empty()) {
@@ -594,11 +600,11 @@ std::tuple<struct llama_model *, struct llama_context *, struct llama_context_pa
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             llama_free(lctx);
             llama_free_model(model);
-            return std::make_tuple(nullptr, nullptr, lparams);
+            return std::make_tuple(nullptr, nullptr);
         }
     }
 
-    return std::make_tuple(model, lctx, lparams);
+    return std::make_tuple(model, lctx);
 }
 
 void console_init(console_state & con_st) {
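
Taken together, the hunks above split the old three-element API in two: llama_init_from_gpt_params once again returns only the (model, context) pair, and the new llama_get_context_params_from_gpt_params exposes the llama_context_params on their own. A minimal caller-side sketch of the restored flow (the wrapper function here is illustrative, not part of the commit; it assumes common.h, and transitively llama.h, is included):

    #include <cstdio>
    #include <tuple>
    #include "common.h" // assumed include; declares both functions per the header hunk below

    // Hypothetical caller, mirroring the pattern used across the examples.
    static int load_with_restored_api(const gpt_params & params) {
        llama_model * model;
        llama_context * ctx;

        // Restored two-element tuple: nothing to std::ignore any more.
        std::tie(model, ctx) = llama_init_from_gpt_params(params);
        if (model == NULL) {
            fprintf(stderr, "%s: error: unable to load model\n", __func__);
            return 1;
        }

        // Callers that still need the parameters request them explicitly.
        struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
        (void) lparams; // e.g. to build extra contexts, as main.cpp does below

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }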

@@ -104,7 +104,8 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 // Model utils
 //
 
-std::tuple<struct llama_model *, struct llama_context *, struct llama_context_params> llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params);
 
 //
 // Console utils

@@ -42,7 +42,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return nullptr;

@@ -41,7 +41,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model
-    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;

@@ -124,12 +124,12 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * guidance_ctx = NULL;
-    struct llama_context_params lparams;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx, lparams) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (params.cfg_scale > 1.f) {
+        struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
         guidance_ctx = llama_new_context_with_model(model, lparams);
     }
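
The hunk above is the one call site that still needs the context parameters after initialization: classifier-free guidance (params.cfg_scale > 1.f) builds a second context on the same model. In isolation the pattern looks like this (a sketch; the teardown at the end is an assumption, the diff does not show it):

    llama_model * model;
    llama_context * ctx;
    llama_context * guidance_ctx = NULL;

    std::tie(model, ctx) = llama_init_from_gpt_params(params);

    if (params.cfg_scale > 1.f) {
        // The same parameters llama_init_from_gpt_params used internally,
        // now obtained through the new helper instead of the return tuple.
        struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
        guidance_ctx = llama_new_context_with_model(model, lparams);
    }

    // ... generation using ctx, and guidance_ctx when CFG is active ...

    if (guidance_ctx) {
        llama_free(guidance_ctx); // assumption: freed like the primary context
    }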

@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx;
 
     // load the model and apply lora adapter, if any
-    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;

@@ -245,7 +245,7 @@ struct llama_server_context
     bool loadModel(const gpt_params &params_)
     {
         params = params_;
-        std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params(params);
+        std::tie(model, ctx) = llama_init_from_gpt_params(params);
         if (model == nullptr)
         {
             LOG_ERROR("unable to load model", {{"model", params_.model}});

@@ -71,7 +71,7 @@ int main(int argc, char ** argv)
     llama_model * model;
     llama_context * ctx;
 
-    std::tie(model, ctx, std::ignore) = llama_init_from_gpt_params( params );
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
 
     if ( model == NULL )
     {