merge: new input format

ngxson 2024-03-01 14:47:00 +01:00
parent b6da762d68
commit 3e6e3668c9
3 changed files with 303 additions and 204 deletions


@@ -10,18 +10,36 @@
 #include <cmath>
 #include <algorithm>
 
-// usage:
-//  ./merge ./path/model_1 CONFIG1 ./path/model_2 CONFIG2
-//
 [[noreturn]]
-static void usage(const char * executable) {
-    printf("usage: %s ./path/model_1 CONFIG1 ./path/model_2 CONFIG2 ./path/output\n\n", executable);
-    printf(" CONFIG must be in format: p0-p1,p2-p3,p4,... Example: 0-5,7,8-12\n");
-    printf(" Optionally, you can specify the scaling for a range of layers, for example: 0-5*0.5,6-7*1. By default, scale will be 0.5. The number of layer start counting from 0.\n");
-    printf(" The embedding layer of the first model will be used\n");
-    printf(" NOTE: currently, only F16 model type is supported\n");
-    exit(1);
+static void usage(const char * executable, int exit_code) {
+    printf("usage: %s -c CONFIG_FILE -o OUTPUT_FILE -m MODEL_PATH -m MODEL_PATH ...\n\n", executable);
+    printf("\n");
+    printf("Merging 2 models and change layers configuration.\n");
+    printf("Merge config format is CSV, without header, one line represents one layer of the output model, columns in the order below:\n");
+    printf("- Model A layer\n");
+    printf("- Model A scale\n");
+    printf("- Model B layer\n");
+    printf("- Model B scale\n");
+    printf("- ...\n");
+    printf("\n");
+    printf("For example:\n");
+    printf("0,1.0,0,0.0 meaning: output layer 0 = A[0]*1.0 + B[0] * 0.0\n");
+    printf("0,1.0,0,0.0 meaning: output layer 1 = A[0]*1.0 + B[0] * 0.0\n");
+    printf("1,0.0,2,0.0 meaning: output layer 2 = A[1]*0.0 + B[2] * 0.0\n");
+    printf("2,0.5,1,0.5 meaning: output layer 3 = A[2]*0.5 + B[1] * 0.5\n");
+    printf("\n");
+    printf("NOTE:\n");
+    printf("- The embedding layer of the first model will be used\n");
+    printf("- Currently, only F16 model type is supported\n");
+    printf("\n");
+    printf("Options:\n");
+    printf("  -h, --help                Show this help message and exit\n");
+    printf("  -c, --config CONFIG_FILE  Path to config file (CSV format)\n");
+    printf("  -m, --model MODEL_PATH    Path to model. This option can be repeated multiple times and must be specified in the right order.\n");
+    printf("  -o, --output OUTPUT_FILE  Path to the output model\n");
+    printf("\n");
+    printf("Example: ./merge -c config.csv -o output.gguf -m model_a.gguf -m model_b.gguf\n");
+    exit(exit_code);
 }
 
 inline std::vector<std::string> str_split(std::string str, const std::string & delimiter) {
@@ -37,82 +55,101 @@ inline std::vector<std::string> str_split(std::string str, const std::string & d
     return output;
 }
 
-static std::vector<struct llama_merge_config> parse_config(std::string & input) {
-    std::vector<struct llama_merge_config> configs;
-    auto intervals = str_split(input, ",");
-    for (auto & interval : intervals) {
-        auto components = str_split(interval, "*");
-        if (components.empty()) {
-            throw std::runtime_error("Config is incorrect");
-        }
-        float scale = components.size() == 2
-            ? std::stof(components[1])
-            : 0.5; // be default
-        auto p0p1 = str_split(components[0], "-");
-        if (p0p1.empty()) {
-            throw std::runtime_error("Layer interval is invalid");
-        }
-        int p0 = std::stoi(p0p1[0]);
-        int p1 = p0p1.size() == 2 ? std::stoi(p0p1[1]) : p0;
-        if (p0 > p1) {
-            throw std::runtime_error("Layer interval is invalid, the end layer number is bigger and start layer number (p0 > p1)");
-        }
-        for (int i = p0; i <= p1; i++) {
-            struct llama_merge_config conf{i, scale, scale};
-            configs.push_back(conf);
-        }
-        // TODO: maybe check for overlap intervals?
-    }
-    return configs;
+static std::vector<struct llama_merge_layer> parse_config(std::string & config_path, size_t n_models, std::vector<int> & buf_srcs, std::vector<float> & buf_scales) {
+    // read file
+    std::ifstream file(config_path);
+    if (!file.is_open()) {
+        throw std::runtime_error("Unable to open file merge config file");
+    }
+    std::ostringstream content;
+    content << file.rdbuf(); // Read the entire file into the stringstream
+    file.close();
+    // allocate memory
+    auto lines = str_split(content.str(), "\n");
+    buf_srcs.resize(lines.size()*n_models);
+    buf_scales.resize(lines.size()*n_models);
+    // process line by line, one line is one layer
+    std::vector<struct llama_merge_layer> layers;
+    for (size_t i_layer = 0; i_layer < lines.size(); i_layer++) {
+        auto columns = str_split(lines[i_layer], ",");
+        if (columns.size() != n_models*2) {
+            std::stringstream ss;
+            ss << "error: line " << i_layer+1 << " is malformed. Expect to have exactly " << n_models*2 << " columns, but got " << columns.size() << " columns";
+            throw std::runtime_error(ss.str());
+        }
+        int   * srcs   = buf_srcs.data()   + i_layer*n_models;
+        float * scales = buf_scales.data() + i_layer*n_models;
+        for (size_t i_model = 0; i_model < n_models; i_model++) {
+            srcs[i_model]   = std::stoi(columns[i_model*2]);
+            scales[i_model] = std::stof(columns[i_model*2 + 1]);
+        }
+        layers.push_back(llama_merge_layer{srcs, scales});
+    }
+    return layers;
 }
 
 int main(int argc, char ** argv) {
-    llama_backend_init();
-
-    if (argc < 6) {
-        usage(argv[0]);
-    }
-
-    std::string fname_model1(argv[1]);
-    std::string config_model1(argv[2]);
-    std::string fname_model2(argv[3]);
-    std::string config_model2(argv[4]);
-    std::string fname_output(argv[5]);
-
-    // TODO: add try catch
-    auto configs1 = parse_config(config_model1);
-    auto configs2 = parse_config(config_model2);
-    std::vector<struct llama_merge_config> configs;
-
-    if (configs1.size() != configs2.size()) {
-        fprintf(stderr, "Number of layers between 2 configs does not match, config1 has %ld layers and config2 has %ld layers\n", configs1.size(), configs2.size());
-    }
-
-    // merge 2 configs
-    printf("Merge configs:\n");
-    for (auto c1 : configs1) {
-        float scale2 = -1;
-        for (auto c2 : configs2) {
-            if (c2.i_layer == c1.i_layer) {
-                scale2 = c2.scale2;
-            }
-        }
-        if (scale2 < 0) {
-            fprintf(stderr, "Cannot find config for layer %d in CONFIG2\n", c1.i_layer);
-            exit(1);
-        }
-        struct llama_merge_config conf{c1.i_layer, c1.scale1, scale2};
-        configs.push_back(conf);
-        printf("  Layer %d: scale1 = %f, scale2 = %f\n", conf.i_layer, conf.scale1, conf.scale2);
-    }
-
-    llama_merge_models(
-        fname_model1.c_str(),
-        fname_model2.c_str(),
-        configs.data(),
-        configs.size(),
-        fname_output.c_str()
-    );
-    llama_backend_free();
+    bool invalid_param = false;
+    std::string config_path;
+    std::vector<std::string> model_paths;
+    std::string output_path;
+
+    std::string arg;
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg == "-h" || arg == "--help") {
+            usage(argv[0], 0);
+        } else if (arg == "-c" || arg == "--config") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            config_path = argv[i];
+        } else if (arg == "-m" || arg == "--model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            model_paths.push_back(argv[i]);
+        } else if (arg == "-o" || arg == "--output") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            output_path = argv[i];
+        }
+    }
+
+    if (invalid_param) {
+        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+    } else if (config_path.empty()) {
+        throw std::invalid_argument("error: missing config path");
+    } else if (model_paths.size() < 2) {
+        throw std::invalid_argument("error: require at least 2 models");
+    } else if (output_path.empty()) {
+        throw std::invalid_argument("error: missing output path");
+    }
+
+    // buffers to hold allocated data
+    std::vector<int>   buf_srcs;
+    std::vector<float> buf_scales;
+    auto layers = parse_config(config_path, model_paths.size(), buf_srcs, buf_scales);
+
+    std::vector<const char*> p_model_paths;
+    for (auto & m : model_paths) {
+        p_model_paths.push_back(m.data());
+    }
+    const struct llama_merge_config config{
+        p_model_paths.data(),
+        p_model_paths.size(),
+        layers.data(),
+        layers.size(),
+        output_path.data(),
+    };
+    llama_merge_models(&config);
+    return 0;
 }
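
As a reference for the new input format above, here is a minimal standalone sketch (not part of the commit) that splits a single merge-config line the same way the new parse_config() does for two models. The line "2,0.5,1,0.5" is taken from the help text; everything else is illustrative.

    // Minimal sketch: parse one CSV merge-config line (srcA,scaleA,srcB,scaleB).
    #include <cstdio>
    #include <sstream>
    #include <string>
    #include <vector>

    int main() {
        const std::string line = "2,0.5,1,0.5"; // example line from the help text (output layer = A[2]*0.5 + B[1]*0.5)
        std::vector<std::string> columns;
        std::stringstream ss(line);
        for (std::string col; std::getline(ss, col, ','); ) {
            columns.push_back(col);
        }
        // columns alternate: source layer in model i, then its scale
        const size_t n_models = columns.size() / 2; // 2 in this example
        for (size_t i_model = 0; i_model < n_models; i_model++) {
            int   src   = std::stoi(columns[i_model * 2]);
            float scale = std::stof(columns[i_model * 2 + 1]);
            printf("model %zu: take layer %d with scale %.2f\n", i_model, src, scale);
        }
        return 0;
    }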

llama.cpp

@@ -11309,11 +11309,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
-static int32_t llama_merge_models_internal(
-        const std::string & fname_inp1,
-        const std::string & fname_inp2,
-        const std::vector<const struct llama_merge_config *> & configs,
-        const std::string & fname_out)
+int32_t llama_merge_models(const struct llama_merge_config * config)
 {
 #if defined(__linux__) || defined(_WIN32)
     constexpr bool use_mmap = true;
@@ -11321,36 +11317,102 @@ static int32_t llama_merge_models_internal(
     constexpr bool use_mmap = false;
 #endif
 
-    llama_model model1;
-    llama_model model2;
-    llama_model_loader ml1(fname_inp1, use_mmap, NULL);
-    llama_model_loader ml2(fname_inp2, use_mmap, NULL);
-
-    auto load_model = [](llama_model_loader & ml, llama_model & model) {
-        ml.init_mapping(false);
-        llm_load_arch(ml, model);
-        llm_load_hparams(ml, model);
-    };
-    load_model(ml1, model1);
-    load_model(ml2, model2);
-
-    if (model1.hparams != model2.hparams) {
-        LLAMA_LOG_ERROR("hparams of two models are different, aborting...");
-        return -1;
-    }
+    // std::move doesn't work with llama_model and llama_model_loader, why?
+    std::vector<std::unique_ptr<llama_model>> models;
+    std::vector<std::unique_ptr<llama_model_loader>> mls;
+    std::vector<no_init<uint8_t>> buf_in;
+    std::vector<no_init<uint8_t>> buf_out;
+    std::set<std::string> ref_names; // list of ref_name per layer
+    int max_input_layers = 0; // number of layers that the input model has
+    std::vector<struct ggml_tensor *> output_tensors;
 
+    // output file
     struct gguf_context * ctx_out = gguf_init_empty();
-    std::ofstream fout(fname_out, std::ios::binary);
+    std::ofstream fout(config->output_path, std::ios::binary);
     fout.exceptions(std::ofstream::failbit); // fail fast on write errors
 
+    // get layer index from tensor name, for example "blk.x.attn_norm.weight"
+    // returns -1 if it is non-layer
+    auto get_i_layer = [&](std::string tensor_name) -> int {
+        int i_layer = -1;
+        return sscanf(tensor_name.c_str(), "blk.%d.", &i_layer) == 1 ? i_layer : -1;
+    };
+
+    // get new tensor name by i_layer and ref_name, for example "blk.x.attn_norm.weight"
+    auto get_name = [&](int i_layer, std::string ref_name) -> std::string {
+        ref_name.erase(0, ref_name.find(".", 4)); // delete the "blk.x" part
+        std::stringstream ss;
+        ss << "blk." << i_layer << ref_name;
+        return ss.str();
+    };
+
+    // remember to call before exit
+    auto clean_up = [&]() {
+        fout.close();
+        for (auto & tensor : output_tensors) {
+            free(tensor);
+        }
+        gguf_free(ctx_out);
+    };
+
+    // load the input models
+    for (size_t i = 0; i < config->n_models; i++) {
+        auto model = std::unique_ptr<llama_model>(new llama_model());
+        auto ml = std::unique_ptr<llama_model_loader>(new llama_model_loader(config->model_paths[i], use_mmap, NULL));
+        ml->init_mapping(false);
+        llm_load_arch(*ml, *model);
+        llm_load_hparams(*ml, *model);
+        if (i > 0 && models[i-1]->hparams != model->hparams) {
+            LLAMA_LOG_ERROR("hparams of input models are different, aborting...");
+            clean_up();
+            return -1;
+        }
+        models.push_back(std::move(model));
+        mls.push_back(std::move(ml));
+    }
+
+    // construct metadata
     {
         // copy the KV pairs from the input file
-        gguf_set_kv(ctx_out, ml1.ctx_gguf);
+        gguf_set_kv(ctx_out, mls[0]->ctx_gguf);
 
-        // populate the original tensors so we get an initial meta data
-        for (int i = 0; i < ml1.n_tensors; ++i) {
-            struct ggml_tensor * meta = ml1.get_tensor_meta(i);
-            gguf_add_tensor(ctx_out, meta);
-        }
+        // correct layer count for output model
+        // TODO: is this key "llama.block_count" this the same for all architectures?
+        gguf_set_val_u32(ctx_out, "llama.block_count", config->n_layers);
+
+        // read input layers
+        for (int i = 0; i < mls[0]->n_tensors; i++) {
+            struct ggml_tensor * meta = mls[0]->get_tensor_meta(i);
+            int i_layer = get_i_layer(ggml_get_name(meta));
+            if (i_layer < 0) {
+                // populate data for non-layers tensor
+                struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
+                memcpy(out_tensor, meta, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
+                gguf_add_tensor(ctx_out, out_tensor);
+                output_tensors.push_back(out_tensor);
+            } else {
+                max_input_layers = std::max(i_layer, max_input_layers);
+                if (i_layer == 0) {
+                    // only extract names of one layer, assuming all layers have the same structure
+                    ref_names.insert(ggml_get_name(meta));
+                }
+            }
+        }
+
+        // populate layers metadata for output model
+        for (size_t i_layer = 0; i_layer < config->n_layers; i_layer++) {
+            for (auto & ref_name : ref_names) {
+                // create new tensor, because new model may have more tensors than input model
+                struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
+                struct ggml_tensor * ref_tensor = mls[0]->get_tensor_meta(ref_name.c_str()); // get ref tensor from layer 0
+                memcpy(out_tensor, ref_tensor, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
+                ggml_set_name(out_tensor, get_name(i_layer, ref_name).c_str()); // set the correct name (with correct i_layer)
+                output_tensors.push_back(out_tensor);
+                gguf_add_tensor(ctx_out, out_tensor);
+                LLAMA_LOG_INFO("%s\n", ggml_get_name(out_tensor));
+            }
+        }
 
         const size_t meta_size = gguf_get_meta_size(ctx_out);
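
A small self-contained sketch (not the committed code) of the naming convention that the get_i_layer() and get_name() helpers above rely on; the tensor name "blk.7.attn_norm.weight" is only an illustrative example of the "blk.<i>.<suffix>" pattern.

    #include <cstdio>
    #include <sstream>
    #include <string>

    // same string handling as the lambdas above, as free functions for illustration
    static int get_i_layer(const std::string & tensor_name) {
        int i_layer = -1; // -1 means "not a per-layer tensor"
        return sscanf(tensor_name.c_str(), "blk.%d.", &i_layer) == 1 ? i_layer : -1;
    }

    static std::string get_name(int i_layer, std::string ref_name) {
        ref_name.erase(0, ref_name.find(".", 4)); // drop the leading "blk.x" part
        std::stringstream ss;
        ss << "blk." << i_layer << ref_name;      // re-attach the requested layer index
        return ss.str();
    }

    int main() {
        std::string ref = "blk.7.attn_norm.weight";        // illustrative tensor name
        printf("layer index: %d\n", get_i_layer(ref));     // prints 7
        printf("renamed    : %s\n", get_name(3, ref).c_str()); // prints blk.3.attn_norm.weight
        return 0;
    }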
@@ -11361,6 +11423,7 @@ static int32_t llama_merge_models_internal(
         ::zeros(fout, meta_size);
     }
 
+    // load tensor data into buffer
     auto read_tensor_data = [&](struct ggml_tensor * tensor, llama_model_loader & ml, std::vector<no_init<uint8_t>> & buf) -> size_t {
         if (!ml.use_mmap) {
             if (buf.size() < ggml_nbytes(tensor)) {
@@ -11372,86 +11435,103 @@ static int32_t llama_merge_models_internal(
         return ggml_nbytes(tensor);
     };
 
-    // map tensor name to its index for ml2
-    std::unordered_map<std::string, int> ml2_name_to_idx;
-    for (int i = 0; i < ml2.n_tensors; ++i) {
-        struct ggml_tensor * tensor = ml1.get_tensor_meta(i);
-        const std::string name = ggml_get_name(tensor);
-        ml2_name_to_idx[name] = i;
-    }
-
-    auto get_config_for_layer = [&](int i_layer) -> const struct llama_merge_config* {
-        for (auto & conf : configs) {
-            if (conf->i_layer == i_layer) {
-                return conf;
-            }
-        }
-        LLAMA_LOG_ERROR("Cannot find llama_merge_config for i_layer=%d\n", i_layer);
-        return nullptr;
-    };
-
-    // process layers
-    for (int i = 0; i < ml1.n_tensors; ++i) {
-        struct ggml_tensor * tensor1 = ml1.get_tensor_meta(i);
-        std::vector<no_init<uint8_t>> buf1;
-        const std::string name = ggml_get_name(tensor1);
-        const size_t tensor_size = ggml_nbytes(tensor1);
-
-        int idx_ml2 = ml2_name_to_idx[name];
-        std::vector<no_init<uint8_t>> buf2;
-        struct ggml_tensor * tensor2 = ml2.get_tensor_meta(idx_ml2);
-
-        // GGML_TYPE_F16
-        std::vector<no_init<uint8_t>> result(tensor_size);
-
-        if (llama_format_tensor_shape(tensor1) != llama_format_tensor_shape(tensor2)) {
-            LLAMA_LOG_ERROR("Tensor shapes are different\n");
-            return -1;
-        }
-
-        int i_layer = -1;
-        if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
-            // non-layer, simply copy
-            read_tensor_data(tensor1, ml1, buf1);
-            memcpy(result.data(), tensor1->data, tensor_size);
-        } else {
-            auto conf = get_config_for_layer(i_layer);
-            read_tensor_data(tensor1, ml1, buf1);
-            read_tensor_data(tensor2, ml2, buf2);
-            LLAMA_LOG_INFO("Merge layer %d with scale1 = %f, scale2 = %f\n", i_layer, conf->scale1, conf->scale2);
-            if (tensor1->type == GGML_TYPE_F16 && tensor2->type == GGML_TYPE_F16) {
-                for (size_t i = 0; i < result.size() / sizeof(float); i++) {
-                    ggml_fp16_t * t1 = (ggml_fp16_t *) tensor1->data;
-                    ggml_fp16_t * t2 = (ggml_fp16_t *) tensor2->data;
-                    ggml_fp16_t * dest = (ggml_fp16_t *) result.data();
-                    float dequant1 = ggml_fp16_to_fp32(t1[i]);
-                    float dequant2 = ggml_fp16_to_fp32(t2[2]);
-                    float res = dequant1 * conf->scale1 + dequant2 * conf->scale2;
-                    dest[i] = ggml_fp32_to_fp16(res);
-                }
-            } else if (tensor1->type == GGML_TYPE_F32 && tensor2->type == GGML_TYPE_F32) {
-                for (size_t i = 0; i < result.size() / sizeof(double); i++) {
-                    float * t1 = (float *) tensor1->data;
-                    float * t2 = (float *) tensor2->data;
-                    float * dest = (float *) result.data();
-                    dest[i] = t1[i] * conf->scale1 + t2[i] * conf->scale2;
-                }
-            } else {
-                LLAMA_LOG_ERROR("Only GGML_TYPE_F16 or GGML_TYPE_F32 is supported, current type = %s\n", ggml_type_name(tensor1->type));
-                return -1;
-            }
-        }
-
-        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s\n",
-            i + 1, ml1.n_tensors,
-            ggml_get_name(tensor1),
-            llama_format_tensor_shape(tensor1).c_str(),
-            ggml_type_name(tensor1->type));
-
-        // write tensor data + padding
-        fout.write((const char *) result.data(), tensor_size);
-        zeros(fout, GGML_PAD(tensor_size, GGUF_DEFAULT_ALIGNMENT) - tensor_size);
-    }
+    // TODO: maybe we should use ggml_add and ggml_scale? and how?
+    auto calc_output_tensor = [&](enum ggml_type type, std::vector<no_init<uint8_t>> & in_buf, float scale, std::vector<no_init<uint8_t>> & out_buf) {
+        GGML_ASSERT(in_buf.size() == out_buf.size());
+        if (type == GGML_TYPE_F16) {
+            GGML_ASSERT(in_buf.size() % sizeof(ggml_fp16_t) == 0);
+            for (size_t i = 0; i < in_buf.size() / sizeof(ggml_fp16_t); i++) {
+                ggml_fp16_t * in = (ggml_fp16_t *) in_buf.data();
+                ggml_fp16_t * dest = (ggml_fp16_t *) out_buf.data();
+                float in_dequant = ggml_fp16_to_fp32(in[i]);
+                float res = in_dequant * scale;
+                dest[i] = ggml_fp32_to_fp16(res);
+            }
+        } else if (type == GGML_TYPE_F32) {
+            GGML_ASSERT(in_buf.size() % sizeof(float) == 0);
+            for (size_t i = 0; i < in_buf.size() / sizeof(float); i++) {
+                float * in = (float *) in_buf.data();
+                float * dest = (float *) out_buf.data();
+                dest[i] = in[i] * scale;
+            }
+        } else {
+            LLAMA_LOG_ERROR("Only GGML_TYPE_F16 or GGML_TYPE_F32 is supported, current type = %s\n", ggml_type_name(type));
+            return -1; // return of lambda, no need clean up
+        }
+        return 0; // return of lambda, no need clean up
+    };
+
+    size_t n_done = 0;
+
+    // process non-layer output tensor
+    for (auto & out_tensor : output_tensors) {
+        std::string name = ggml_get_name(out_tensor);
+        int i_layer_out = get_i_layer(name.c_str());
+        std::vector<no_init<uint8_t>> buf(ggml_nbytes(out_tensor));
+        if (i_layer_out >= 0) {
+            continue;
+        }
+        struct ggml_tensor * in_tensor = mls[0]->get_tensor_meta(name.c_str());
+        if (in_tensor == nullptr) {
+            LLAMA_LOG_ERROR("Cannot find layer name %s from base model\n", name.c_str());
+            clean_up();
+            return -1;
+        }
+        read_tensor_data(in_tensor, *mls[0], buf); // read from first model
+        n_done++;
+        LLAMA_LOG_INFO("[%4ld/%4ld] %36s - [%s], type = %6s\n",
+            n_done, output_tensors.size(),
+            name.c_str(),
+            llama_format_tensor_shape(out_tensor).c_str(),
+            ggml_type_name(out_tensor->type));
+
+        // write tensor data + padding
+        fout.write((const char *) buf.data(), buf.size());
+        zeros(fout, GGML_PAD(buf.size(), GGUF_DEFAULT_ALIGNMENT) - buf.size());
+    }
+
+    // process layer output tensor
+    for (auto & out_tensor : output_tensors) {
+        std::vector<no_init<uint8_t>> in_buf(ggml_nbytes(out_tensor));
+        std::vector<no_init<uint8_t>> out_buf(ggml_nbytes(out_tensor));
+
+        std::string out_name = ggml_get_name(out_tensor);
+        int i_layer_out = get_i_layer(out_name.c_str());
+        auto layer = config->layers[i_layer_out];
+
+        if (i_layer_out < 0) {
+            continue;
+        }
+
+        for (size_t i_model = 0; i_model < config->n_models; i_model++) {
+            int src_layer = layer.srcs[i_model]; // source layer
+            float scale = layer.scales[i_model];
+            std::string src_name = get_name(src_layer, out_name); // find the correct tensor based on src_layer
+            struct ggml_tensor * in_tensor = mls[i_model]->get_tensor_meta(src_name.c_str());
+            int res;
+            if (in_tensor == nullptr) {
+                LLAMA_LOG_ERROR("Cannot find layer name %s from model %ld\n", src_name.c_str(), i_model + 1);
+                clean_up();
+                return -1;
+            }
+            read_tensor_data(in_tensor, *mls[i_model], in_buf);
+            res = calc_output_tensor(in_tensor->type, in_buf, scale, out_buf);
+            if (res < 0) {
+                clean_up();
+                return res;
+            }
+        }
+
+        n_done++;
+        LLAMA_LOG_INFO("[%4ld/%4ld] %36s - [%s], type = %6s\n",
+            n_done, output_tensors.size(),
+            out_name.c_str(),
+            llama_format_tensor_shape(out_tensor).c_str(),
+            ggml_type_name(out_tensor->type));
+
+        // write tensor data + padding
+        fout.write((const char *) out_buf.data(), out_buf.size());
+        zeros(fout, GGML_PAD(out_buf.size(), GGUF_DEFAULT_ALIGNMENT) - out_buf.size());
+    }
 
     // go back to beginning of file and write the updated meta data
@@ -11462,31 +11542,10 @@ static int32_t llama_merge_models_internal(
         fout.write((const char *) data.data(), data.size());
     }
 
-    fout.close();
-    gguf_free(ctx_out);
+    clean_up();
     return 0;
 }
 
-int32_t llama_merge_models(
-        const char * fname_inp1,
-        const char * fname_inp2,
-        const struct llama_merge_config * configs,
-        const int n_configs,
-        const char * fname_out)
-{
-    std::vector<const struct llama_merge_config *> v_configs(n_configs);
-    for (int i = 0; i < n_configs; i++) {
-        v_configs[i] = &configs[i];
-    }
-    return llama_merge_models_internal(
-        fname_inp1,
-        fname_inp2,
-        v_configs,
-        fname_out
-    );
-}
-
 static int llama_apply_lora_from_file_internal(
     const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
 ) {
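
For clarity, a plain-float sketch (not the committed code) of the per-element arithmetic that the merge config describes, following the help-text example "2,0.5,1,0.5" where output layer 3 = A[2]*0.5 + B[1]*0.5. The element values below are made up, and in llama.cpp the F16 path additionally converts each element with ggml_fp16_to_fp32()/ggml_fp32_to_fp16().

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> a = {0.10f, -0.20f, 0.30f}; // hypothetical elements of A's layer 2
        std::vector<float> b = {0.40f,  0.00f, -0.10f}; // hypothetical elements of B's layer 1
        const float scale_a = 0.5f;
        const float scale_b = 0.5f;

        std::vector<float> out(a.size());
        for (size_t i = 0; i < a.size(); i++) {
            out[i] = a[i] * scale_a + b[i] * scale_b; // weighted element-wise sum
        }
        for (float v : out) {
            printf("%.3f\n", v); // 0.250, -0.100, 0.100
        }
        return 0;
    }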

llama.h

@@ -327,11 +327,18 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // used to merge models
+    struct llama_merge_layer {
+        const int   * srcs;   // contains n_models elements
+        const float * scales; // contains n_models elements
+    };
+
     struct llama_merge_config {
-        const int i_layer;
-        const float scale1;
-        const float scale2;
-        // TODO add support for embeding and output layers
+        const char ** model_paths;
+        const size_t n_models;
+        const struct llama_merge_layer * layers;
+        const size_t n_layers;
+        const char * output_path;
     };
 
     // Helpers for getting default parameters
@@ -422,11 +429,7 @@ extern "C" {
             const llama_model_quantize_params * params);
 
     LLAMA_API int32_t llama_merge_models(
-            const char * fname_inp1,
-            const char * fname_inp2,
-            const struct llama_merge_config * configs,
-            const int n_configs,
-            const char * fname_out);
+            const struct llama_merge_config * config);
 
     // Apply a LoRA adapter to a loaded model
     // path_base_model is the path to a higher quality model to use as a base for
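
Finally, a hedged sketch of how a caller might fill the new structs and invoke llama_merge_models() after this change; the model paths, the two-layer table, and the output name are placeholders, and error handling is omitted.

    // Placeholder values throughout; only the struct layout follows the header above.
    #include "llama.h"

    int main() {
        const char * model_paths[] = { "model_a.gguf", "model_b.gguf" };

        // one llama_merge_layer per output layer; srcs/scales each hold n_models entries
        const int   srcs0[]   = { 0, 0 };      // output layer 0 <- A[0], B[0]
        const float scales0[] = { 1.0f, 0.0f };
        const int   srcs1[]   = { 1, 1 };      // output layer 1 <- A[1], B[1]
        const float scales1[] = { 0.5f, 0.5f };
        const llama_merge_layer layers[] = {
            { srcs0, scales0 },
            { srcs1, scales1 },
        };

        const llama_merge_config config = {
            model_paths,
            2,            // n_models
            layers,
            2,            // n_layers
            "output.gguf",
        };
        return llama_merge_models(&config);
    }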