From 3e6e3668c94edbff2e4b9a023e2eb0d964e7ea20 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Fri, 1 Mar 2024 14:47:00 +0100
Subject: [PATCH] merge: new input format

---
 examples/merge/merge.cpp | 197 +++++++++++++++-----------
 llama.cpp                | 289 +++++++++++++++++++++++----------------
 llama.h                  |  21 +--
 3 files changed, 303 insertions(+), 204 deletions(-)

diff --git a/examples/merge/merge.cpp b/examples/merge/merge.cpp
index 1cb0f6f56..024ae5692 100644
--- a/examples/merge/merge.cpp
+++ b/examples/merge/merge.cpp
@@ -10,18 +10,36 @@
 #include
 #include
-
-// usage:
-//  ./merge ./path/model_1 CONFIG1 ./path/model_2 CONFIG2
-//
 [[noreturn]]
-static void usage(const char * executable) {
-    printf("usage: %s ./path/model_1 CONFIG1 ./path/model_2 CONFIG2 ./path/output\n\n", executable);
-    printf("  CONFIG must be in format: p0-p1,p2-p3,p4,... Example: 0-5,7,8-12\n");
-    printf("  Optionally, you can specify the scaling for a range of layers, for example: 0-5*0.5,6-7*1. By default, scale will be 0.5. The number of layer start counting from 0.\n");
-    printf("  The embedding layer of the first model will be used\n");
-    printf("  NOTE: currently, only F16 model type is supported\n");
-    exit(1);
+static void usage(const char * executable, int exit_code) {
+    printf("usage: %s -c CONFIG_FILE -o OUTPUT_FILE -m MODEL_PATH -m MODEL_PATH ...\n\n", executable);
+    printf("\n");
+    printf("Merge two or more models and change the layer configuration.\n");
+    printf("The merge config is a CSV file without a header; each line describes one layer of the output model, with columns in the order below:\n");
+    printf("- Model A layer\n");
+    printf("- Model A scale\n");
+    printf("- Model B layer\n");
+    printf("- Model B scale\n");
+    printf("- ...\n");
+    printf("\n");
+    printf("For example:\n");
+    printf("0,1.0,0,0.0    meaning: output layer 0 = A[0]*1.0 + B[0]*0.0\n");
+    printf("0,1.0,0,0.0    meaning: output layer 1 = A[0]*1.0 + B[0]*0.0\n");
+    printf("1,0.0,2,0.0    meaning: output layer 2 = A[1]*0.0 + B[2]*0.0\n");
+    printf("2,0.5,1,0.5    meaning: output layer 3 = A[2]*0.5 + B[1]*0.5\n");
+    printf("\n");
+    printf("NOTE:\n");
+    printf("- The embedding layer of the first model will be used\n");
+    printf("- Currently, only the F16 model type is supported\n");
+    printf("\n");
+    printf("Options:\n");
+    printf("  -h, --help                 Show this help message and exit\n");
+    printf("  -c, --config CONFIG_FILE   Path to the config file (CSV format)\n");
+    printf("  -m, --model MODEL_PATH     Path to an input model. This option can be repeated multiple times and the models must be specified in the right order.\n");
+    printf("  -o, --output OUTPUT_FILE   Path to the output model\n");
+    printf("\n");
+    printf("Example: ./merge -c config.csv -o output.gguf -m model_a.gguf -m model_b.gguf\n");
+    exit(exit_code);
 }
 
 inline std::vector<std::string> str_split(std::string str, const std::string & delimiter) {
@@ -37,82 +55,101 @@ inline std::vector<std::string> str_split(std::string str, const std::string & d
     return output;
 }
 
-static std::vector<struct llama_merge_config> parse_config(std::string & input) {
-    std::vector<struct llama_merge_config> configs;
-    auto intervals = str_split(input, ",");
-    for (auto & interval : intervals) {
-        auto components = str_split(interval, "*");
-        if (components.empty()) {
-            throw std::runtime_error("Config is incorrect");
-        }
-        float scale = components.size() == 2
-            ? std::stof(components[1])
-            : 0.5; // be default
-        auto p0p1 = str_split(components[0], "-");
-        if (p0p1.empty()) {
-            throw std::runtime_error("Layer interval is invalid");
-        }
-        int p0 = std::stoi(p0p1[0]);
-        int p1 = p0p1.size() == 2 ? std::stoi(p0p1[1]) : p0;
-        if (p0 > p1) {
-            throw std::runtime_error("Layer interval is invalid, the end layer number is bigger and start layer number (p0 > p1)");
-        }
-        for (int i = p0; i <= p1; i++) {
-            struct llama_merge_config conf{i, scale, scale};
-            configs.push_back(conf);
-        }
-        // TODO: maybe check for overlap intervals?
+static std::vector<struct llama_merge_layer> parse_config(std::string & config_path, size_t n_models, std::vector<int> & buf_srcs, std::vector<float> & buf_scales) {
+    // read file
+    std::ifstream file(config_path);
+    if (!file.is_open()) {
+        throw std::runtime_error("Unable to open merge config file");
     }
-    return configs;
+    std::ostringstream content;
+    content << file.rdbuf(); // read the entire file into the stringstream
+    file.close();
+
+    // allocate memory
+    auto lines = str_split(content.str(), "\n");
+    buf_srcs.resize(lines.size()*n_models);
+    buf_scales.resize(lines.size()*n_models);
+
+    // process line by line, one line is one layer
+    std::vector<struct llama_merge_layer> layers;
+    for (size_t i_layer = 0; i_layer < lines.size(); i_layer++) {
+        auto columns = str_split(lines[i_layer], ",");
+        if (columns.size() != n_models*2) {
+            std::stringstream ss;
+            ss << "error: line " << i_layer+1 << " is malformed. Expected exactly " << n_models*2 << " columns, but got " << columns.size() << " columns";
+            throw std::runtime_error(ss.str());
+        }
+        int   * srcs   = buf_srcs.data()   + i_layer*n_models;
+        float * scales = buf_scales.data() + i_layer*n_models;
+        for (size_t i_model = 0; i_model < n_models; i_model++) {
+            srcs[i_model]   = std::stoi(columns[i_model*2]);
+            scales[i_model] = std::stof(columns[i_model*2 + 1]);
+        }
+        layers.push_back(llama_merge_layer{srcs, scales});
+    }
+    return layers;
 }
 
 int main(int argc, char ** argv) {
-    llama_backend_init();
+    bool invalid_param = false;
+    std::string config_path;
+    std::vector<std::string> model_paths;
+    std::string output_path;
 
-    if (argc < 6) {
-        usage(argv[0]);
-    }
-
-    std::string fname_model1(argv[1]);
-    std::string config_model1(argv[2]);
-    std::string fname_model2(argv[3]);
-    std::string config_model2(argv[4]);
-    std::string fname_output(argv[5]);
-
-    // TODO: add try catch
-    auto configs1 = parse_config(config_model1);
-    auto configs2 = parse_config(config_model2);
-    std::vector<struct llama_merge_config> configs;
-
-    if (configs1.size() != configs2.size()) {
-        fprintf(stderr, "Number of layers between 2 configs does not match, config1 has %ld layers and config2 has %ld layers\n", configs1.size(), configs2.size());
-    }
-
-    // merge 2 configs
-    printf("Merge configs:\n");
-    for (auto c1 : configs1) {
-        float scale2 = -1;
-        for (auto c2 : configs2) {
-            if (c2.i_layer == c1.i_layer) {
-                scale2 = c2.scale2;
+    std::string arg;
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg == "-h" || arg == "--help") {
+            usage(argv[0], 0);
+        } else if (arg == "-c" || arg == "--config") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
             }
+            config_path = argv[i];
+        } else if (arg == "-m" || arg == "--model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            model_paths.push_back(argv[i]);
+        } else if (arg == "-o" || arg == "--output") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            output_path = argv[i];
         }
-        if (scale2 < 0) {
-            fprintf(stderr, "Cannot find config for layer %d in CONFIG2\n", c1.i_layer);
-            exit(1);
-        }
-        struct llama_merge_config conf{c1.i_layer, c1.scale1, scale2};
-        configs.push_back(conf);
-
-        printf("  Layer %d: scale1 = %f, scale2 = %f\n", conf.i_layer, conf.scale1, conf.scale2);
     }
-    llama_merge_models(
-        fname_model1.c_str(),
-        fname_model2.c_str(),
-        configs.data(),
-        configs.size(),
-        fname_output.c_str()
-    );
-    llama_backend_free();
+    if (invalid_param) {
+        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+    } else if (config_path.empty()) {
+        throw std::invalid_argument("error: missing config path");
+    } else if (model_paths.size() < 2) {
+        throw std::invalid_argument("error: require at least 2 models");
+    } else if (output_path.empty()) {
+        throw std::invalid_argument("error: missing output path");
+    }
+
+    // buffers to hold allocated data
+    std::vector<int>   buf_srcs;
+    std::vector<float> buf_scales;
+
+    auto layers = parse_config(config_path, model_paths.size(), buf_srcs, buf_scales);
+    std::vector<const char *> p_model_paths;
+    for (auto & m : model_paths) {
+        p_model_paths.push_back(m.data());
+    }
+    const struct llama_merge_config config{
+        p_model_paths.data(),
+        p_model_paths.size(),
+        layers.data(),
+        layers.size(),
+        output_path.data(),
+    };
+
+    llama_merge_models(&config);
+
+    return 0;
 }
\ No newline at end of file
diff --git a/llama.cpp b/llama.cpp
index 382abf9a2..f48cc6935 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11309,11 +11309,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
-static int32_t llama_merge_models_internal(
-    const std::string & fname_inp1,
-    const std::string & fname_inp2,
-    const std::vector<const llama_merge_config *> & configs,
-    const std::string & fname_out)
+int32_t llama_merge_models(const struct llama_merge_config * config)
 {
 #if defined(__linux__) || defined(_WIN32)
     constexpr bool use_mmap = true;
@@ -11321,36 +11317,102 @@ static int32_t llama_merge_models_internal(
     constexpr bool use_mmap = false;
 #endif
 
-    llama_model model1;
-    llama_model model2;
-    llama_model_loader ml1(fname_inp1, use_mmap, NULL);
-    llama_model_loader ml2(fname_inp2, use_mmap, NULL);
+    // std::move doesn't work with llama_model and llama_model_loader, why?
+    std::vector<std::unique_ptr<llama_model>>        models;
+    std::vector<std::unique_ptr<llama_model_loader>> mls;
+    std::vector<no_init<uint8_t>> buf_in;
+    std::vector<no_init<uint8_t>> buf_out;
+    std::set<std::string> ref_names; // list of ref_name per layer
+    int max_input_layers = 0; // number of layers that the input model has
+    std::vector<struct ggml_tensor *> output_tensors;
 
-    auto load_model = [](llama_model_loader & ml, llama_model & model) {
-        ml.init_mapping(false);
-        llm_load_arch(ml, model);
-        llm_load_hparams(ml, model);
+    // output file
+    struct gguf_context * ctx_out = gguf_init_empty();
+    std::ofstream fout(config->output_path, std::ios::binary);
+    fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+
+    // get the layer index from a tensor name, for example "blk.x.attn_norm.weight"
+    // returns -1 if it is a non-layer tensor
+    auto get_i_layer = [&](std::string tensor_name) -> int {
+        int i_layer = -1;
+        return sscanf(tensor_name.c_str(), "blk.%d.", &i_layer) == 1 ? i_layer : -1;
     };
-    load_model(ml1, model1);
-    load_model(ml2, model2);
 
-    if (model1.hparams != model2.hparams) {
-        LLAMA_LOG_ERROR("hparams of two models are different, aborting...");
-        return -1;
+    // get the new tensor name from i_layer and ref_name, for example "blk.x.attn_norm.weight"
+    auto get_name = [&](int i_layer, std::string ref_name) -> std::string {
+        ref_name.erase(0, ref_name.find(".", 4)); // delete the "blk.x" part
+        std::stringstream ss;
+        ss << "blk." << i_layer << ref_name;
+        return ss.str();
+    };
+
+    // remember to call this before exiting
+    auto clean_up = [&]() {
+        fout.close();
+        for (auto & tensor : output_tensors) {
+            free(tensor);
+        }
+        gguf_free(ctx_out);
+    };
+
+    // load the input models
+    for (size_t i = 0; i < config->n_models; i++) {
+        auto model = std::unique_ptr<llama_model>(new llama_model());
+        auto ml    = std::unique_ptr<llama_model_loader>(new llama_model_loader(config->model_paths[i], use_mmap, NULL));
+        ml->init_mapping(false);
+        llm_load_arch(*ml, *model);
+        llm_load_hparams(*ml, *model);
+
+        if (i > 0 && models[i-1]->hparams != model->hparams) {
+            LLAMA_LOG_ERROR("hparams of input models are different, aborting...");
+            clean_up();
+            return -1;
+        }
+
+        models.push_back(std::move(model));
+        mls.push_back(std::move(ml));
     }
 
-    struct gguf_context * ctx_out = gguf_init_empty();
-    std::ofstream fout(fname_out, std::ios::binary);
-    fout.exceptions(std::ofstream::failbit); // fail fast on write errors
-
+    // construct metadata
     {
         // copy the KV pairs from the input file
-        gguf_set_kv(ctx_out, ml1.ctx_gguf);
+        gguf_set_kv(ctx_out, mls[0]->ctx_gguf);
+
+        // correct the layer count for the output model
+        // TODO: is the key "llama.block_count" the same for all architectures?
+        gguf_set_val_u32(ctx_out, "llama.block_count", config->n_layers);
 
-        // populate the original tensors so we get an initial meta data
-        for (int i = 0; i < ml1.n_tensors; ++i) {
-            struct ggml_tensor * meta = ml1.get_tensor_meta(i);
-            gguf_add_tensor(ctx_out, meta);
+        // read input layers
+        for (int i = 0; i < mls[0]->n_tensors; i++) {
+            struct ggml_tensor * meta = mls[0]->get_tensor_meta(i);
+            int i_layer = get_i_layer(ggml_get_name(meta));
+            if (i_layer < 0) {
+                // populate data for non-layer tensors
+                struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
+                memcpy(out_tensor, meta, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
+                gguf_add_tensor(ctx_out, out_tensor);
+                output_tensors.push_back(out_tensor);
+            } else {
+                max_input_layers = std::max(i_layer, max_input_layers);
+                if (i_layer == 0) {
+                    // only extract the names of one layer, assuming all layers have the same structure
+                    ref_names.insert(ggml_get_name(meta));
+                }
+            }
+        }
+
+        // populate layer metadata for the output model
+        for (size_t i_layer = 0; i_layer < config->n_layers; i_layer++) {
+            for (auto & ref_name : ref_names) {
+                // create a new tensor, because the output model may have more tensors than the input model
+                struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
+                struct ggml_tensor * ref_tensor = mls[0]->get_tensor_meta(ref_name.c_str()); // get the ref tensor from layer 0
+                memcpy(out_tensor, ref_tensor, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
+                ggml_set_name(out_tensor, get_name(i_layer, ref_name).c_str()); // set the correct name (with the correct i_layer)
+                output_tensors.push_back(out_tensor);
+                gguf_add_tensor(ctx_out, out_tensor);
+                LLAMA_LOG_INFO("%s\n", ggml_get_name(out_tensor));
+            }
         }
 
         const size_t meta_size = gguf_get_meta_size(ctx_out);
@@ -11361,6 +11423,7 @@ static int32_t llama_merge_models_internal(
         ::zeros(fout, meta_size);
     }
 
+    // load tensor data into a buffer
     auto read_tensor_data = [&](struct ggml_tensor * tensor, llama_model_loader & ml, std::vector<no_init<uint8_t>> & buf) -> size_t {
         if (!ml.use_mmap) {
             if (buf.size() < ggml_nbytes(tensor)) {
@@ -11372,86 +11435,103 @@ static int32_t llama_merge_models_internal(
         return ggml_nbytes(tensor);
     };
 
-    // map tensor name to its index for ml2
-    std::unordered_map<std::string, int> ml2_name_to_idx;
-    for (int i = 0; i < ml2.n_tensors; ++i) {
-        struct ggml_tensor * tensor = ml1.get_tensor_meta(i);
-        const std::string name = ggml_get_name(tensor);
-        ml2_name_to_idx[name] = i;
-    }
-
-    auto get_config_for_layer = [&](int i_layer) -> const struct llama_merge_config* {
-        for (auto & conf : configs) {
-            if (conf->i_layer == i_layer) {
-                return conf;
+    // TODO: maybe we should use ggml_add and ggml_scale? and how?
+    // accumulate scale * in_buf into out_buf (out_buf must be zero-initialized before the first call)
+    auto calc_output_tensor = [&](enum ggml_type type, std::vector<no_init<uint8_t>> & in_buf, float scale, std::vector<uint8_t> & out_buf) {
+        GGML_ASSERT(in_buf.size() == out_buf.size());
+        if (type == GGML_TYPE_F16) {
+            GGML_ASSERT(in_buf.size() % sizeof(ggml_fp16_t) == 0);
+            for (size_t i = 0; i < in_buf.size() / sizeof(ggml_fp16_t); i++) {
+                ggml_fp16_t * in   = (ggml_fp16_t *) in_buf.data();
+                ggml_fp16_t * dest = (ggml_fp16_t *) out_buf.data();
+                float in_dequant = ggml_fp16_to_fp32(in[i]);
+                float res = ggml_fp16_to_fp32(dest[i]) + in_dequant * scale;
+                dest[i] = ggml_fp32_to_fp16(res);
             }
+        } else if (type == GGML_TYPE_F32) {
+            GGML_ASSERT(in_buf.size() % sizeof(float) == 0);
+            for (size_t i = 0; i < in_buf.size() / sizeof(float); i++) {
+                float * in   = (float *) in_buf.data();
+                float * dest = (float *) out_buf.data();
+                dest[i] += in[i] * scale;
+            }
+        } else {
+            LLAMA_LOG_ERROR("Only GGML_TYPE_F16 or GGML_TYPE_F32 is supported, current type = %s\n", ggml_type_name(type));
+            return -1; // lambda return value, no need to clean up
        }
-        LLAMA_LOG_ERROR("Cannot find llama_merge_config for i_layer=%d\n", i_layer);
-        return nullptr;
+        return 0; // lambda return value, no need to clean up
     };
 
-    // process layers
-    for (int i = 0; i < ml1.n_tensors; ++i) {
-        struct ggml_tensor * tensor1 = ml1.get_tensor_meta(i);
-        std::vector<no_init<uint8_t>> buf1;
-        const std::string name = ggml_get_name(tensor1);
-        const size_t tensor_size = ggml_nbytes(tensor1);
-
-        int idx_ml2 = ml2_name_to_idx[name];
-        std::vector<no_init<uint8_t>> buf2;
-        struct ggml_tensor * tensor2 = ml2.get_tensor_meta(idx_ml2);
-
-        // GGML_TYPE_F16
-        std::vector<no_init<uint8_t>> result(tensor_size);
-
-        if (llama_format_tensor_shape(tensor1) != llama_format_tensor_shape(tensor2)) {
-            LLAMA_LOG_ERROR("Tensor shapes are different\n");
+    size_t n_done = 0;
+    // process non-layer output tensors
+    for (auto & out_tensor : output_tensors) {
+        std::string name = ggml_get_name(out_tensor);
+        int i_layer_out = get_i_layer(name.c_str());
+        if (i_layer_out >= 0) {
+            continue;
+        }
+        std::vector<no_init<uint8_t>> buf(ggml_nbytes(out_tensor));
+        struct ggml_tensor * in_tensor = mls[0]->get_tensor_meta(name.c_str());
+        if (in_tensor == nullptr) {
+            LLAMA_LOG_ERROR("Cannot find layer name %s from base model\n", name.c_str());
+            clean_up();
             return -1;
         }
+        read_tensor_data(in_tensor, *mls[0], buf); // read from first model
 
-        int i_layer = -1;
-        if (sscanf(name.c_str(), "blk.%d.", &i_layer) != 1) {
-            // non-layer, simply copy
-            read_tensor_data(tensor1, ml1, buf1);
-            memcpy(result.data(), tensor1->data, tensor_size);
-        } else {
-            auto conf = get_config_for_layer(i_layer);
-            read_tensor_data(tensor1, ml1, buf1);
-            read_tensor_data(tensor2, ml2, buf2);
-            LLAMA_LOG_INFO("Merge layer %d with scale1 = %f, scale2 = %f\n", i_layer, conf->scale1, conf->scale2);
+        n_done++;
+        LLAMA_LOG_INFO("[%4ld/%4ld] %36s - [%s], type = %6s\n",
+            n_done, output_tensors.size(),
+            name.c_str(),
+            llama_format_tensor_shape(out_tensor).c_str(),
+            ggml_type_name(out_tensor->type));
 
-            if (tensor1->type == GGML_TYPE_F16 && tensor2->type == GGML_TYPE_F16) {
-                for (size_t i = 0; i < result.size() / sizeof(float); i++) {
-                    ggml_fp16_t * t1 = (ggml_fp16_t *) tensor1->data;
-                    ggml_fp16_t * t2 = (ggml_fp16_t *) tensor2->data;
-                    ggml_fp16_t * dest = (ggml_fp16_t *) result.data();
-                    float dequant1 = ggml_fp16_to_fp32(t1[i]);
-                    float dequant2 = ggml_fp16_to_fp32(t2[2]);
-                    float res = dequant1 * conf->scale1 + dequant2 * conf->scale2;
-                    dest[i] = ggml_fp32_to_fp16(res);
-                }
-            } else if (tensor1->type == GGML_TYPE_F32 && tensor2->type == GGML_TYPE_F32) {
-                for (size_t i = 0; i < result.size() / sizeof(double); i++) {
-                    float * t1 = (float *) tensor1->data;
-                    float * t2 = (float *) tensor2->data;
-                    float * dest = (float *) result.data();
-                    dest[i] = t1[i] * conf->scale1 + t2[i] * conf->scale2;
-                }
-            } else {
-                LLAMA_LOG_ERROR("Only GGML_TYPE_F16 or GGML_TYPE_F32 is supported, current type = %s\n", ggml_type_name(tensor1->type));
+        // write tensor data + padding
+        fout.write((const char *) buf.data(), buf.size());
+        zeros(fout, GGML_PAD(buf.size(), GGUF_DEFAULT_ALIGNMENT) - buf.size());
+    }
+
+    // process layer output tensors
+    for (auto & out_tensor : output_tensors) {
+        std::string out_name = ggml_get_name(out_tensor);
+        int i_layer_out = get_i_layer(out_name.c_str());
+
+        if (i_layer_out < 0) {
+            continue; // non-layer tensors were already written above
+        }
+
+        auto layer = config->layers[i_layer_out];
+        std::vector<no_init<uint8_t>> in_buf(ggml_nbytes(out_tensor));
+        std::vector<uint8_t> out_buf(ggml_nbytes(out_tensor), 0); // zero-initialized, used as the accumulator
+
+        for (size_t i_model = 0; i_model < config->n_models; i_model++) {
+            int src_layer = layer.srcs[i_model]; // source layer
+            float scale = layer.scales[i_model];
+            std::string src_name = get_name(src_layer, out_name); // find the correct tensor based on src_layer
+            struct ggml_tensor * in_tensor = mls[i_model]->get_tensor_meta(src_name.c_str());
+            int res;
+            if (in_tensor == nullptr) {
+                LLAMA_LOG_ERROR("Cannot find layer name %s from model %ld\n", src_name.c_str(), i_model + 1);
+                clean_up();
                 return -1;
             }
+            read_tensor_data(in_tensor, *mls[i_model], in_buf);
+            res = calc_output_tensor(in_tensor->type, in_buf, scale, out_buf);
+            if (res < 0) {
+                clean_up();
+                return res;
+            }
         }
 
-        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s\n",
-            i + 1, ml1.n_tensors,
-            ggml_get_name(tensor1),
-            llama_format_tensor_shape(tensor1).c_str(),
-            ggml_type_name(tensor1->type));
+        n_done++;
+        LLAMA_LOG_INFO("[%4ld/%4ld] %36s - [%s], type = %6s\n",
+            n_done, output_tensors.size(),
+            out_name.c_str(),
+            llama_format_tensor_shape(out_tensor).c_str(),
+            ggml_type_name(out_tensor->type));
 
         // write tensor data + padding
-        fout.write((const char *) result.data(), tensor_size);
-        zeros(fout, GGML_PAD(tensor_size, GGUF_DEFAULT_ALIGNMENT) - tensor_size);
+        fout.write((const char *) out_buf.data(), out_buf.size());
+        zeros(fout, GGML_PAD(out_buf.size(), GGUF_DEFAULT_ALIGNMENT) - out_buf.size());
     }
 
     // go back to beginning of file and write the updated meta data
@@ -11462,31 +11542,10 @@ static int32_t llama_merge_models_internal(
         fout.write((const char *) data.data(), data.size());
     }
 
-    fout.close();
-
-    gguf_free(ctx_out);
+    clean_up();
 
     return 0;
 }
 
-int32_t llama_merge_models(
-        const char * fname_inp1,
-        const char * fname_inp2,
-        const struct llama_merge_config * configs,
-        const int n_configs,
-        const char * fname_out)
-{
-    std::vector<const llama_merge_config *> v_configs(n_configs);
-    for (int i = 0; i < n_configs; i++) {
-        v_configs[i] = &configs[i];
-    }
-    return llama_merge_models_internal(
-        fname_inp1,
-        fname_inp2,
-        v_configs,
-        fname_out
-    );
-}
-
 static int llama_apply_lora_from_file_internal(
     const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
 ) {
diff --git a/llama.h b/llama.h
index 7806476f8..23cabe291 100644
--- a/llama.h
+++ b/llama.h
@@ -327,11 +327,18 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // used to merge models
+    struct llama_merge_layer {
+        const int   * srcs;   // contains n_models elements
+        const float * scales; // contains n_models elements
+    };
+
     struct llama_merge_config {
-        const int i_layer;
-        const float scale1;
-        const float scale2;
-        // TODO add support for embeding and output layers
+        const char ** model_paths;
+        const size_t n_models;
+        const struct llama_merge_layer * layers;
+        const size_t n_layers;
+        const char * output_path;
     };
 
     // Helpers for getting default parameters
@@ -422,11 +429,7 @@ extern "C" {
             const llama_model_quantize_params * params);
 
     LLAMA_API int32_t llama_merge_models(
-        const char * fname_inp1,
-        const char * fname_inp2,
-        const struct llama_merge_config * configs,
-        const int n_configs,
-        const char * fname_out);
+        const struct llama_merge_config * config);
 
     // Apply a LoRA adapter to a loaded model
     // path_base_model is the path to a higher quality model to use as a base for
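
As a concrete illustration of the CSV format described in the new usage() text, a merge config for two input models could look like the following (the file name, layer picks, and scales are only illustrative). Each line is one output layer, with columns src_A,scale_A,src_B,scale_B:

    0,0.5,0,0.5
    1,1.0,1,0.0
    2,1.0,2,0.0
    3,0.5,3,0.5

It would be used exactly as in the example from the help text:

    ./merge -c config.csv -o output.gguf -m model_a.gguf -m model_b.gguf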
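The same merge can also be driven directly through the new C API declared in llama.h above. The sketch below is illustrative only (model paths, layer choices, and scales are placeholders, and error handling is minimal); it mirrors what examples/merge/merge.cpp builds after parsing its CSV config:

    #include "llama.h"
    #include <cstdio>

    int main() {
        // two hypothetical input models
        const char * model_paths[] = { "model_a.gguf", "model_b.gguf" };

        // output layer 0 = A[0]*0.5 + B[0]*0.5
        const int   srcs0[]   = { 0, 0 };
        const float scales0[] = { 0.5f, 0.5f };
        // output layer 1 = A[1]*1.0 + B[1]*0.0
        const int   srcs1[]   = { 1, 1 };
        const float scales1[] = { 1.0f, 0.0f };

        const struct llama_merge_layer layers[] = {
            { srcs0, scales0 },
            { srcs1, scales1 },
        };

        const struct llama_merge_config config = {
            model_paths,
            2,              // n_models
            layers,
            2,              // n_layers in the output model
            "output.gguf",  // output_path
        };

        if (llama_merge_models(&config) != 0) {
            fprintf(stderr, "merge failed\n");
            return 1;
        }
        return 0;
    }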