improve handling of export-lora arguments

print errors and warnings when files could not be read or created
xaedes 2023-09-24 14:42:52 +02:00
parent da05205af6
commit 2912f17010


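For context, the first hunk below makes the model-base and model-out options effectively required: each one is compared against its default value and, if it was never set, the tool prints an error plus the usage text and exits. The standalone sketch below illustrates that detect-unset-by-comparing-to-default pattern; export_params, print_usage and require_set are illustrative stand-ins, not the actual export-lora helpers.

#include <cstdio>
#include <cstdlib>
#include <string>

struct export_params {
    std::string fn_model_base;   // path to the base model, empty by default
    std::string fn_model_out;    // path to the merged output model, empty by default
};

static void print_usage(const char * prog) {
    fprintf(stderr, "usage: %s [options]  (see the real export-lora usage text for its flags)\n", prog);
}

// An option that still equals its default value is treated as "never set",
// which is how the new checks in export_lora_params_parse detect it.
static void require_set(const std::string & value, const std::string & def,
                        const char * name, const char * prog) {
    if (value == def) {
        fprintf(stderr, "error: please specify a filename for %s.\n", name);
        print_usage(prog);
        exit(1);
    }
}

int main(int /*argc*/, char ** argv) {
    export_params default_params;            // defaults: both paths empty
    export_params params = default_params;   // would normally be filled in by argument parsing
    require_set(params.fn_model_base, default_params.fn_model_base, "model-base", argv[0]);
    require_set(params.fn_model_out,  default_params.fn_model_out,  "model-out",  argv[0]);
    return 0;
}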
@@ -198,6 +198,17 @@ static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_
             exit(1);
         }
     }
+    if (params->fn_model_base == default_params.fn_model_base) {
+        fprintf(stderr, "error: please specify a filename for model-base.\n");
+        export_lora_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
+    if (params->fn_model_out == default_params.fn_model_out) {
+        fprintf(stderr, "error: please specify a filename for model-out.\n");
+        export_lora_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
     if (invalid_param) {
         fprintf(stderr, "error: invalid parameter for argument: '%s'\n", arg.c_str());
         export_lora_print_usage(argc, argv, &default_params);
@@ -206,6 +217,13 @@ static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_
     return true;
 }
 
+static void free_lora(struct lora_data * lora) {
+    if (lora->ctx != NULL) {
+        ggml_free(lora->ctx);
+    }
+    delete lora;
+}
+
 static struct lora_data * load_lora(struct lora_info * info) {
     struct lora_data * result = new struct lora_data;
     result->info = *info;
@@ -215,7 +233,10 @@ static struct lora_data * load_lora(struct lora_info * info) {
 
     struct llama_file file(info->filename.c_str(), "rb");
     if (file.fp == NULL) {
-        return result;
+        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
+            info->filename.c_str());
+        free_lora(result);
+        return NULL;
     }
 
     struct ggml_init_params params_ggml;
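The hunk above changes load_lora so that an unreadable adapter file no longer yields a half-initialized lora_data: it prints a warning, frees the partially constructed object with free_lora, and returns NULL for the caller to skip. A minimal sketch of that load-or-NULL contract follows, using a simplified blob type with the file reading stubbed out; all names here are illustrative rather than taken from export-lora.

#include <cstdio>
#include <cstddef>
#include <string>

// Simplified stand-in for lora_data: an object that owns a resource.
struct blob {
    unsigned char * data = nullptr;
    size_t          size = 0;
};

static void free_blob(blob * b) {
    delete[] b->data;
    delete b;
}

// Mirrors the new load_lora contract: on an unreadable file, print a warning,
// free the partially constructed object and return NULL so the caller can skip it.
static blob * load_blob(const std::string & filename) {
    blob * result = new blob;
    FILE * fp = fopen(filename.c_str(), "rb");
    if (fp == NULL) {
        fprintf(stderr, "warning: Could not open '%s'. Ignoring it.\n", filename.c_str());
        free_blob(result);
        return NULL;
    }
    // ... read the file contents into result->data here ...
    fclose(fp);
    return result;
}

int main() {
    if (blob * b = load_blob("does-not-exist.bin")) {
        free_blob(b);   // only reached for successfully loaded files
    }
    return 0;
}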
@@ -278,12 +299,6 @@ static struct lora_data * load_lora(struct lora_info * info) {
     return result;
 }
 
-static void free_lora(struct lora_data * lora) {
-    if (lora->ctx != NULL) {
-        ggml_free(lora->ctx);
-    }
-    delete lora;
-}
 
 static struct ggml_cgraph * build_graph_lora(
     struct ggml_context * ctx,
@@ -304,6 +319,9 @@ static struct ggml_cgraph * build_graph_lora(
 }
 
 static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int n_threads) {
+    if (lora->ctx == NULL) {
+        return false;
+    }
     std::string name = ggml_get_name(tensor);
     std::string name_a = name + std::string(".loraA");
     std::string name_b = name + std::string(".loraB");
@@ -354,7 +372,19 @@ static void export_lora(struct export_lora_params * params) {
     // load all loras
     std::vector<struct lora_data *> loras;
     for (size_t i = 0; i < params->lora.size(); ++i) {
-        loras.push_back(load_lora(&params->lora[i]));
+        struct lora_data * lora = load_lora(&params->lora[i]);
+        if (lora != NULL) {
+            loras.push_back(lora);
+        }
     }
+    if (loras.size() == 0) {
+        fprintf(stderr, "warning: no lora adapters will be applied.\n");
+    }
+
+    // open input file
+    struct llama_file fin(params->fn_model_base.c_str(), "rb");
+    if (!fin.fp) {
+        die_fmt("Could not open file '%s'\n", params->fn_model_base.c_str());
+    }
 
     // open base model gguf, read tensors without their data
@@ -388,7 +418,6 @@ static void export_lora(struct export_lora_params * params) {
     gguf_get_meta_data(gguf_out, meta.data());
     fout.write_raw(meta.data(), meta.size());
 
-    struct llama_file fin(params->fn_model_base.c_str(), "rb");
     std::vector<uint8_t> data;
     std::vector<uint8_t> padding;
     for (int i=0; i < n_tensors; ++i) {
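Taken together, the export_lora hunks apply two different failure policies: a lora adapter that cannot be opened only triggers a warning and is skipped (with a further warning if nothing remains to apply), while a base model that cannot be opened aborts the export via die_fmt. The condensed sketch below shows the same policy with plain stdio in place of llama_file and die_fmt; the file names and helper names are made up for illustration.

#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Unreadable adapters are only worth a warning; the caller skips them.
static FILE * open_optional(const std::string & path) {
    FILE * fp = fopen(path.c_str(), "rb");
    if (fp == NULL) {
        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n", path.c_str());
    }
    return fp;
}

// An unreadable base model makes the whole export pointless, so fail hard.
static FILE * open_required(const std::string & path) {
    FILE * fp = fopen(path.c_str(), "rb");
    if (fp == NULL) {
        fprintf(stderr, "error: Could not open file '%s'\n", path.c_str());
        exit(1);
    }
    return fp;
}

int main() {
    std::vector<std::string> adapters = { "adapter-a.bin", "adapter-b.bin" };
    std::vector<FILE *> opened;
    for (const std::string & a : adapters) {
        if (FILE * fp = open_optional(a)) {
            opened.push_back(fp);
        }
    }
    if (opened.empty()) {
        fprintf(stderr, "warning: no lora adapters will be applied.\n");
    }
    FILE * base = open_required("base-model.gguf");
    // ... merge the opened adapters into the base model here ...
    fclose(base);
    for (FILE * fp : opened) {
        fclose(fp);
    }
    return 0;
}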