Mixing multiple LoRA adapters is now possible

Pass more than one '--lora FNAME' argument to apply more than one LoRA adapter.
Use '--lora-scaled FNAME S' to specify a user-defined scale S for an adapter; a plain '--lora' applies the adapter with a scale of 1.0.
For example (hypothetical file names): ./main -m model.bin --lora style.bin --lora-scaled subject.bin 0.5
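Each '--lora'/'--lora-scaled' argument ultimately becomes one call to llama_model_apply_lora_from_file, which gains a scale parameter in this commit. A minimal caller-side sketch of the same thing through the C API, assuming the llama.h loading functions of this revision (llama_backend_init / llama_load_model_from_file era); the model path, adapter names, and thread count are made up:

    #include "llama.h"

    int main() {
        llama_backend_init(false /* numa */);

        struct llama_context_params lparams = llama_context_default_params();
        struct llama_model * model = llama_load_model_from_file("models/7B/ggml-model.bin", lparams);
        if (model == NULL) {
            return 1;
        }

        // '--lora lora-a.bin' pushes {path, 1.0f}, i.e. the default scale is 1.0
        if (llama_model_apply_lora_from_file(model, "lora-a.bin", 1.0f, NULL, 4) != 0) {
            return 1;
        }

        // '--lora-scaled lora-b.bin 0.5' applies a second adapter at half strength;
        // as in the parsing code below, an optional --lora-base model would only be
        // passed for the first adapter, later adapters patch the already-patched tensors
        if (llama_model_apply_lora_from_file(model, "lora-b.bin", 0.5f, NULL, 4) != 0) {
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }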
xaedes 2023-08-20 18:36:20 +02:00
parent 37dfb544aa
commit d61ed6b431
5 changed files with 49 additions and 13 deletions


@@ -310,7 +310,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            params.lora_adapter = argv[i];
+            params.lora_adapter.push_back({argv[i], 1.0f});
+            params.use_mmap = false;
+        } else if (arg == "--lora-scaled") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            const char * lora_adapter = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
             params.use_mmap = false;
         } else if (arg == "--lora-base") {
             if (++i >= argc) {
@@ -601,6 +613,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --verbose-prompt print prompt before generation\n");
     fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
     fprintf(stdout, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+    fprintf(stdout, " --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
     fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, " -m FNAME, --model FNAME\n");
     fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
@@ -677,10 +690,15 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         return std::make_tuple(nullptr, nullptr);
     }
 
-    if (!params.lora_adapter.empty()) {
+    for (int i = 0; i < params.lora_adapter.size(); ++i) {
+        const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+        float lora_scale = std::get<1>(params.lora_adapter[i]);
         int err = llama_model_apply_lora_from_file(model,
-                                             params.lora_adapter.c_str(),
-                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             lora_adapter.c_str(),
+                                             lora_scale,
+                                             ((i > 0) || params.lora_base.empty())
+                                                ? NULL
+                                                : params.lora_base.c_str(),
                                              params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);


@@ -62,8 +62,8 @@ struct gpt_params {
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
-    std::string lora_adapter = ""; // lora adapter path
+    std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
     std::string lora_base = ""; // base model path for the lora adapter
 
     bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
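For illustration, a small standalone sketch of the new field's shape and of how entries are stored and read elsewhere in this commit (adapter names and scales are made up):

    #include <cstdio>
    #include <string>
    #include <tuple>
    #include <vector>

    int main() {
        // same shape as the new gpt_params::lora_adapter member
        std::vector<std::tuple<std::string, float>> lora_adapter;
        lora_adapter.push_back({"lora-a.bin", 1.0f}); // from '--lora lora-a.bin'
        lora_adapter.push_back({"lora-b.bin", 0.5f}); // from '--lora-scaled lora-b.bin 0.5'

        for (size_t i = 0; i < lora_adapter.size(); ++i) {
            const std::string & path = std::get<0>(lora_adapter[i]);
            const float        scale = std::get<1>(lora_adapter[i]);
            std::printf("%s (scale %.2f)\n", path.c_str(), scale);
        }
        return 0;
    }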


@@ -869,7 +869,23 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-            params.lora_adapter = argv[i];
+            params.lora_adapter.push_back({argv[i], 1.0f});
+            params.use_mmap = false;
+        }
+        else if (arg == "--lora-scaled")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            const char * lora_adapter = argv[i];
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])});
             params.use_mmap = false;
         }
         else if (arg == "--lora-base")
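In both parsers the scale S is read with std::stof; for reference, std::stof throws std::invalid_argument on malformed input, a standard-library behavior illustrated by this small standalone sketch:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    int main() {
        try {
            float s = std::stof("0.5");   // well-formed scale, e.g. '--lora-scaled adapter.bin 0.5'
            std::printf("scale = %.2f\n", s);
            std::stof("not-a-number");    // malformed scale: std::stof throws std::invalid_argument
        } catch (const std::exception & err) {
            std::printf("parse error: %s\n", err.what());
        }
        return 0;
    }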


@@ -3401,7 +3401,7 @@ int llama_model_quantize(
     }
 }
 
-int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
     fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
 
     const int64_t t_start_lora_us = ggml_time_us();
@@ -3433,7 +3433,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
     int32_t lora_alpha;
     fin.read((char *) &lora_r, sizeof(lora_r));
     fin.read((char *) &lora_alpha, sizeof(lora_alpha));
-    float scaling = (float)lora_alpha / (float)lora_r;
+    float scaling = scale * (float)lora_alpha / (float)lora_r;
 
     fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
@@ -3682,18 +3682,18 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
     return 0;
 }
 
-int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
     try {
-        return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads);
+        return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads);
     } catch (const std::exception & err) {
         fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.what());
         return 1;
     }
 }
 
-int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) {
+int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) {
     try {
-        return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads);
+        return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
     } catch (const std::exception & err) {
         fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.what());
         return 1;
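The user-defined scale simply multiplies the usual LoRA factor lora_alpha / lora_r computed in llama_apply_lora_from_file_internal above. A short worked example with made-up adapter hyperparameters:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical hyperparameters read from the adapter file header
        int32_t lora_r     = 16;
        int32_t lora_alpha = 32;

        // same formula as llama_apply_lora_from_file_internal after this commit
        float scale   = 0.5f;                                       // user-defined scale from '--lora-scaled ... 0.5'
        float scaling = scale * (float)lora_alpha / (float)lora_r;  // 0.5 * 32 / 16 = 1.0
        std::printf("scaling = %.2f\n", scaling);
        return 0;
    }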


@@ -249,6 +249,7 @@ extern "C" {
     LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
             struct llama_context * ctx,
             const char * path_lora,
+            float scale,
             const char * path_base_model,
             int n_threads),
             "please use llama_model_apply_lora_from_file instead");
@@ -256,6 +257,7 @@ extern "C" {
     LLAMA_API int llama_model_apply_lora_from_file(
             const struct llama_model * model,
             const char * path_lora,
+            float scale,
             const char * path_base_model,
             int n_threads);