Add LoRA support (#820)

This commit is contained in:
slaren 2023-04-17 17:28:55 +02:00 committed by GitHub
parent efd05648c8
commit 315a95a4d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 753 additions and 41 deletions

View file

@ -139,6 +139,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.model = argv[i];
} else if (arg == "--lora") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.lora_adapter = argv[i];
params.use_mmap = false;
} else if (arg == "--lora-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.lora_base = argv[i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
@ -242,6 +255,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
}
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, "\n");

View file

@ -31,11 +31,12 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_prefix = ""; // string to prefix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs

View file

@ -114,6 +114,17 @@ int main(int argc, char ** argv) {
}
}
if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
}
}
// print system information
{
fprintf(stderr, "\n");

View file

@ -134,6 +134,17 @@ int main(int argc, char ** argv) {
}
}
if (!params.lora_adapter.empty()) {
int err = llama_apply_lora_from_file(ctx,
params.lora_adapter.c_str(),
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
params.n_threads);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return 1;
}
}
// print system information
{
fprintf(stderr, "\n");