self merge ok

ngxson 2024-03-03 18:04:48 +01:00
parent 65730438aa
commit a032bb6ca2
4 changed files with 248 additions and 137 deletions

@@ -7,17 +7,19 @@
# Supported verbs:
# - linear: merge linearly, parameters: source_layer,source_layer,scale,scale
# - slerp: spherical linear interpolation, parameters: source_layer,source_layer,t
# - repeat: repeat a layer in the same output model (to reduce file size)
#
# For example:
#
# - copy: copy a layer from one input model, parameters: source_model,source_layer
#########################
# Example:
# This is the first layer of the output model:
# For all tensors, we want slerp(model[0].layer[0], model[1].layer[0], 0.1)
# Except for the "attn_output" tensor, for which we want t=0.5 instead of t=0.1
output layer 0
all slerp 0,0,0.9
attn_output slerp 0,0,0.9
all slerp 0,0,0.1
attn_output slerp 0,0,0.5
# For the next layer, we want: model[0].layer[1]*0.6 + model[1].layer[1]*0.4
# Except for the "attn_output" tensor, where we want to use slerp with t=0.9
@@ -26,13 +28,96 @@ output layer 1
all linear 1,1,0.6,0.4
attn_output slerp 1,1,0.9
output layer 2
all linear 2,2,1.0,0.0
# For the next layer, we want to copy from model[0].layer[2]
# repeat the first layers defined earlier in this file
output layer 2
all copy 0,2
output layer 3
all repeat 0
all copy 0,3
# For the next layer, we want to copy from model[1].layer[4]
output layer 4
all repeat 1
all copy 1,4
output layer 5
all copy 1,5
output layer 6
all linear 6,6,0.1,0.9
output layer 7
all linear 7,7,0.1,0.9
output layer 8
all linear 8,8,0.1,0.9
output layer 9
all linear 9,9,0.1,0.9
output layer 10
all linear 10,10,0.1,0.9
output layer 11
all linear 11,11,0.1,0.9
output layer 12
all linear 12,12,0.1,0.9
output layer 13
all linear 13,13,0.3333,0.6666
output layer 14
all linear 14,14,0.3333,0.6666
output layer 15
all linear 15,15,0.3333,0.6666
output layer 16
all linear 16,16,0.3333,0.6666
output layer 17
all linear 17,17,0.3333,0.6666
output layer 18
all linear 18,18,0.3333,0.6666
output layer 19
all linear 19,19,0.3333,0.6666
output layer 20
all slerp 20,20,0.8
output layer 21
all slerp 21,21,0.8
output layer 22
all slerp 22,22,0.8
output layer 23
all slerp 23,23,0.8
output layer 24
all slerp 24,24,0.8
output layer 25
all slerp 25,25,0.8
output layer 26
all slerp 26,26,0.8
output layer 27
all slerp 27,27,0.8
output layer 28
all slerp 28,28,0.8
output layer 29
all slerp 29,29,0.8
output layer 30
all slerp 30,30,0.8
output layer 31
all slerp 31,31,0.8

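For reference, a minimal sketch of what the linear and slerp verbs compute element-wise once both source tensors have been dequantized to F32. The slerp function below is the textbook spherical interpolation over the flattened tensors; as the llama.cpp hunks further down show, this commit still stubs slerp out, so treat it as the intended math rather than what ships here.

#include <cmath>
#include <cstddef>

// linear: out[i] = in0[i]*scale0 + in1[i]*scale1   ("linear layer0,layer1,scale0,scale1")
static void merge_linear(const float * in0, const float * in1, float * out,
                         size_t n, float scale0, float scale1) {
    for (size_t i = 0; i < n; i++) {
        out[i] = in0[i]*scale0 + in1[i]*scale1;
    }
}

// slerp: spherical linear interpolation over the two tensors viewed as flat
// vectors ("slerp layer0,layer1,t"); falls back to a plain lerp when the
// vectors are nearly parallel
static void merge_slerp(const float * in0, const float * in1, float * out,
                        size_t n, float t) {
    double dot = 0.0, norm0 = 0.0, norm1 = 0.0;
    for (size_t i = 0; i < n; i++) {
        dot   += (double) in0[i] * in1[i];
        norm0 += (double) in0[i] * in0[i];
        norm1 += (double) in1[i] * in1[i];
    }
    const double cos_theta = dot / (std::sqrt(norm0) * std::sqrt(norm1) + 1e-12);
    if (std::fabs(cos_theta) > 0.9995) {          // almost parallel -> lerp is stable enough
        merge_linear(in0, in1, out, n, 1.0f - t, t);
        return;
    }
    const double theta = std::acos(cos_theta);
    const double w0 = std::sin((1.0 - t) * theta) / std::sin(theta);
    const double w1 = std::sin(t * theta)         / std::sin(theta);
    for (size_t i = 0; i < n; i++) {
        out[i] = (float) (w0 * in0[i] + w1 * in1[i]);
    }
}
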
@@ -125,7 +125,8 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
struct llama_merge_inst ins;
ins.method = LLAMA_MERGE_COPY;
strcpy(ins.name, name.c_str());
strcpy(ins.srcs[0], name.c_str());
strcpy(ins.srcs[0], name.c_str()); // always take the first model
strcpy(ins.srcs[1], "");
instructions.push_back(ins);
} else {
// tensor belongs to a layer
@@ -177,7 +178,7 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
auto parts = str_split(line, " ");
if (parts.size() != 3) {
raise_err(i_line, "does not follow format: \"target (space) verb (space) arguments\"");
raise_err(i_line, "does not follow format: \"target (space) verb (space) parameters\"");
}
auto target = parts[0];
@@ -197,7 +198,7 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
auto linear = [&](struct llama_merge_inst & ins, std::string unit) {
if (params.size() != 4) {
raise_err(i_line, "verb \"linear\" requires exactly 4 params");
raise_err(i_line, "verb \"linear\" requires exactly 4 parameters");
}
ins.method = LLAMA_MERGE_LINEAR;
int src0 = std::stoi(params[0]);
@@ -211,7 +212,7 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
auto slerp = [&](struct llama_merge_inst & ins, std::string unit) {
if (params.size() != 3) {
raise_err(i_line, "verb \"slerp\" requires exactly 3 params");
raise_err(i_line, "verb \"slerp\" requires exactly 3 parameters");
}
ins.method = LLAMA_MERGE_SLERP;
int src0 = std::stoi(params[0]);
@@ -222,14 +223,33 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
is_layer_empty = false;
};
auto repeat = [&](struct llama_merge_inst & ins, std::string unit) {
/*auto repeat = [&](struct llama_merge_inst & ins, std::string unit) {
if (params.size() != 1) {
raise_err(i_line, "verb \"repeat\" requires exactly 1 param");
raise_err(i_line, "verb \"repeat\" requires exactly 1 parameter");
}
ins.method = LLAMA_MERGE_REPEAT;
int src0 = std::stoi(params[0]);
strcpy(ins.srcs[0], get_tensor_name(src0, unit).c_str());
is_layer_empty = false;
};*/
auto copy = [&](struct llama_merge_inst & ins, std::string unit) {
if (params.size() != 2) {
raise_err(i_line, "verb \"copy\" requires exactly 2 parameters");
}
ins.method = LLAMA_MERGE_COPY;
int model = std::stoi(params[0]);
int layer = std::stoi(params[1]);
if (model == 0) {
strcpy(ins.srcs[0], get_tensor_name(layer, unit).c_str());
strcpy(ins.srcs[1], "");
} else if (model == 1) {
strcpy(ins.srcs[0], "");
strcpy(ins.srcs[1], get_tensor_name(layer, unit).c_str());
} else {
raise_err(i_line, "can only copy from model 0 or 1");
}
is_layer_empty = false;
};
auto apply_verb = [&](struct llama_merge_inst & ins, std::string unit) {
@@ -238,12 +258,16 @@ static std::vector<struct llama_merge_inst> parse_config(std::string & config_pa
} else if (verb == "slerp") {
slerp(ins, unit);
} else if (verb == "repeat") {
repeat(ins, unit);
// repeat(ins, unit);
raise_err(i_line, "repeat is currently not supported");
} else if (verb == "copy") {
copy(ins, unit);
} else {
raise_err(i_line, "invalid verb: " + verb);
}
};
// TODO: what if the user does not use "all"? we may miss some tensors
if (target == "all") {
for (auto & u : units) {
apply_verb(layer[u], u);

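As a rough illustration of the grammar parse_config expects: each non-comment line is split on spaces into target / verb / parameters, and the parameters are then split on commas before the verb lambdas above consume them. The str_split below is only a stand-in (the real helper is not shown in this diff), but the decomposition it produces is the one the code relies on.

#include <string>
#include <vector>

// stand-in for the str_split helper used above (the real one is not shown in this diff)
static std::vector<std::string> str_split(const std::string & s, const std::string & sep) {
    std::vector<std::string> out;
    size_t start = 0, pos;
    while ((pos = s.find(sep, start)) != std::string::npos) {
        out.push_back(s.substr(start, pos - start));
        start = pos + sep.size();
    }
    out.push_back(s.substr(start));
    return out;
}

// auto parts  = str_split("attn_output slerp 1,1,0.9", " ");
// // parts  -> {"attn_output", "slerp", "1,1,0.9"}   (target, verb, parameters)
// auto params = str_split(parts[2], ",");
// // params -> {"1", "1", "0.9"}                     (layer in model 0, layer in model 1, t)
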
llama.cpp

@@ -11358,14 +11358,12 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
#else
constexpr bool use_mmap = false;
#endif
/*
// std::move doesn't work with llama_model and llama_model_loader, why?
std::vector<std::unique_ptr<llama_model>> models;
std::vector<std::unique_ptr<llama_model_loader>> mls;
std::vector<no_init<uint8_t>> buf_in;
std::vector<no_init<uint8_t>> buf_out;
std::set<std::string> ref_names; // list of ref_name per layer
int max_input_layers = 0; // number of layers that the input model has
std::vector<struct ggml_tensor *> output_tensors;
// output file
@@ -11373,21 +11371,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
std::ofstream fout(config->output_path, std::ios::binary);
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
// get layer index from tensor name, for example "blk.x.attn_norm.weight"
// returns -1 if it is non-layer
auto get_i_layer = [&](std::string tensor_name) -> int {
int i_layer = -1;
return sscanf(tensor_name.c_str(), "blk.%d.", &i_layer) == 1 ? i_layer : -1;
};
// get new tensor name by i_layer and ref_name, for example "blk.x.attn_norm.weight"
auto get_name = [&](int i_layer, std::string ref_name) -> std::string {
ref_name.erase(0, ref_name.find(".", 4)); // delete the "blk.x" part
std::stringstream ss;
ss << "blk." << i_layer << ref_name;
return ss.str();
};
// remember to call before exit
auto clean_up = [&]() {
fout.close();
@@ -11398,7 +11381,8 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
};
// load the input models
for (size_t i = 0; i < config->n_models; i++) {
static const size_t n_models = 2;
for (size_t i = 0; i < n_models; i++) {
auto model = std::unique_ptr<llama_model>(new llama_model());
auto ml = std::unique_ptr<llama_model_loader>(new llama_model_loader(config->model_paths[i], use_mmap, NULL));
ml->init_mapping(false);
@@ -11415,6 +11399,12 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
mls.push_back(std::move(ml));
}
// for the copy verb, resolve the source tensor and the model it comes from
auto get_src_tensor_for_copy = [&](const struct llama_merge_inst ins, size_t & i_model) {
i_model = std::string(ins.srcs[0]).empty() ? 1 : 0;
return mls[i_model]->get_tensor_meta(ins.srcs[i_model]);
};
// construct metadata
{
// copy the KV pairs from the input file
@@ -11424,40 +11414,58 @@
std::stringstream ss;
ss << mls[0]->get_arch_name() << ".block_count";
gguf_set_val_u32(ctx_out, ss.str().c_str(), config->n_layers);
printf("====> Set new value of %s = %ld\n", ss.str().c_str(), config->n_layers);
// read input layers; process non-layer tensors first (embedding, output,...)
for (int i = 0; i < mls[0]->n_tensors; i++) {
struct ggml_tensor * meta = mls[0]->get_tensor_meta(i);
int i_layer = get_i_layer(ggml_get_name(meta));
if (i_layer < 0) {
// populate data for non-layer tensors (embedding, output,...)
struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
memcpy(out_tensor, meta, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
gguf_add_tensor(ctx_out, out_tensor);
output_tensors.push_back(out_tensor);
} else {
max_input_layers = std::max(i_layer, max_input_layers);
if (i_layer == 0) {
// only extract names of one layer, assuming all layers have the same structure
ref_names.insert(ggml_get_name(meta));
}
}
}
LLAMA_LOG_INFO("====> Set new value of %s = %ld\n", ss.str().c_str(), config->n_layers);
// populate layers metadata for output model
for (size_t i_layer = 0; i_layer < config->n_layers; i_layer++) {
for (auto & ref_name : ref_names) {
// create a new tensor, because the output model may have more tensors than the input model
struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
struct ggml_tensor * ref_tensor = mls[0]->get_tensor_meta(ref_name.c_str()); // get ref tensor from layer 0
memcpy(out_tensor, ref_tensor, GGML_TENSOR_SIZE); // copy metadata (shape, type,...)
ggml_set_name(out_tensor, get_name(i_layer, ref_name).c_str()); // set the correct name (with correct i_layer)
output_tensors.push_back(out_tensor);
gguf_add_tensor(ctx_out, out_tensor);
// TODO: reject non-requantize-able type (one that requires imatrix)
// populate metadata for output tensors
auto push_tensor = [&](struct ggml_tensor * ref, const char * name) {
struct ggml_tensor * out_tensor = (struct ggml_tensor *) malloc(GGML_TENSOR_SIZE);
if (ref != nullptr) {
// copy metadata (shape, type,...)
memcpy(out_tensor, ref, GGML_TENSOR_SIZE);
}
ggml_set_name(out_tensor, name);
gguf_add_tensor(ctx_out, out_tensor);
output_tensors.push_back(out_tensor);
};
for (size_t i = 0; i < config->n_insts; i++) {
const struct llama_merge_inst ins = config->insts[i];
struct ggml_tensor * t0;
struct ggml_tensor * t1;
// TODO: reject non-requantize-able type (one that requires imatrix)
if (ins.method == LLAMA_MERGE_COPY) {
// simply copy the metadata from whichever model the instruction references
size_t i_model;
t0 = get_src_tensor_for_copy(ins, i_model);
push_tensor(t0, ins.name);
} else if (ins.method == LLAMA_MERGE_LINEAR || ins.method == LLAMA_MERGE_SLERP) {
t0 = mls[0]->get_tensor_meta(ins.srcs[0]);
t1 = mls[1]->get_tensor_meta(ins.srcs[1]);
if (llama_format_tensor_shape(t0) != llama_format_tensor_shape(t1)) {
LLAMA_LOG_ERROR("some tensors does not have the same shape");
clean_up();
return -1;
}
push_tensor(t0, ins.name);
} else if (ins.method == LLAMA_MERGE_REPEAT) {
// TODO: in theory, we can point 2 tensors to the same offset, but here we're unable to do that, because offset is currently managed by gguf_add_tensor()
GGML_ASSERT(false);
/*t0 = nullptr;
std::string search_tensor(ins.srcs[0]);
for (auto & tensor : output_tensors) {
if (std::string(ggml_get_name(tensor)) == search_tensor) {
t0 = tensor;
break;
}
}
if (t0 == nullptr) {
LLAMA_LOG_ERROR("cannot find source tensor to repeat");
clean_up();
return -1;
}
push_tensor(t0, ins.name);*/
} else {
GGML_ASSERT(false); // should never happen
}
// TODO: how to reuse tensor (duplicated layers)? we can play with ctx->infos[tensor_idx].offset
}
const size_t meta_size = gguf_get_meta_size(ctx_out);
@@ -11481,8 +11489,11 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
};
size_t n_done = 0;
size_t n_curr = 0;
auto log_step = [&](const struct ggml_tensor * tensor) {
auto write_output_tensor = [&](const struct ggml_tensor * tensor, void * data) {
// write tensor data + padding
const size_t len = ggml_nbytes(tensor);
fout.write((const char *) data, len);
zeros(fout, GGML_PAD(len, GGUF_DEFAULT_ALIGNMENT) - len);
n_done++;
LLAMA_LOG_INFO("[%4ld/%4ld] %36s - [%s], input type = %6s\n",
n_done, output_tensors.size(),
@@ -11491,84 +11502,82 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
ggml_type_name(tensor->type));
};
// process non-layer output tensor
for (auto & out_tensor : output_tensors) {
std::string name = ggml_get_name(out_tensor);
int i_layer_out = get_i_layer(name.c_str());
std::vector<no_init<uint8_t>> buf;
if (i_layer_out >= 0) {
continue;
}
n_curr++;
struct ggml_tensor * in_tensor = mls[0]->get_tensor_meta(name.c_str());
if (in_tensor == nullptr) {
LLAMA_LOG_ERROR("Cannot find layer name %s from base model\n", name.c_str());
clean_up();
return -1;
}
read_tensor_data(in_tensor, *mls[0], buf); // read from first model
// write tensor data + padding
const size_t write_size = ggml_nbytes(out_tensor);
fout.write((const char *) in_tensor->data, write_size);
zeros(fout, GGML_PAD(write_size, GGUF_DEFAULT_ALIGNMENT) - write_size);
log_step(out_tensor);
}
// TODO: allow user to set n_threads
const int n_threads = std::thread::hardware_concurrency();
std::vector<std::thread> workers;
workers.reserve(n_threads);
// process tensors associated with a layer
for (auto & out_tensor : output_tensors) {
// process the instructions one by one
GGML_ASSERT(config->n_insts == output_tensors.size());
for (size_t i = 0; i < config->n_insts; i++) {
const struct llama_merge_inst ins = config->insts[i];
struct ggml_tensor * t0;
struct ggml_tensor * t1;
struct ggml_tensor * out_tensor = output_tensors[i];
const size_t n_elements = ggml_nelements(out_tensor);
std::vector<no_init<uint8_t>> in_buf;
std::vector<no_init<float>> f32_in_buf; // dequant it internally
std::vector<no_init<uint8_t>> in_buf0;
std::vector<no_init<float>> f32_in_buf0; // dequant it internally
std::vector<no_init<uint8_t>> in_buf1;
std::vector<no_init<float>> f32_in_buf1; // dequant it internally
std::vector<float> f32_out_buf(n_elements, 0.0); // do not resize!
std::vector<uint8_t> out_buf(ggml_nbytes(out_tensor)); // do not resize!
std::string out_name = ggml_get_name(out_tensor);
int i_layer_out = get_i_layer(out_name.c_str());
auto layer = config->layers[i_layer_out];
if (i_layer_out < 0) {
continue; // skip non-layer tensors
if (ins.method == LLAMA_MERGE_COPY) {
LLAMA_LOG_INFO("copy\n");
size_t i_model;
t0 = get_src_tensor_for_copy(ins, i_model);
read_tensor_data(t0, *mls[i_model], in_buf0);
write_output_tensor(out_tensor, t0->data);
continue;
}
for (size_t i_model = 0; i_model < config->n_models; i_model++) {
int src_layer = layer.srcs[i_model]; // source layer
float scale = layer.scales[i_model];
std::string src_name = get_name(src_layer, out_name); // find the correct tensor based on src_layer
struct ggml_tensor * in_tensor = mls[i_model]->get_tensor_meta(src_name.c_str());
if (in_tensor == nullptr) {
LLAMA_LOG_ERROR("Cannot find layer name %s from model %ld\n", src_name.c_str(), i_model + 1);
clean_up();
return -1; // stop
}
read_tensor_data(in_tensor, *mls[i_model], in_buf);
// dequant the tensor to FP32
// dequantize the tensor to FP32
auto dequantize = [&](struct ggml_tensor * in_tensor, std::vector<no_init<float>> & f32_in_buf) {
if (in_tensor->type != GGML_TYPE_F32) {
//LLAMA_LOG_ERROR("dequant ");
LLAMA_LOG_INFO("dequant ");
llama_convert_tensor_internal(in_tensor, f32_in_buf, workers, n_elements, n_threads);
} else {
// if we already have f32, just copy it
//LLAMA_LOG_ERROR("f32_copy ");
LLAMA_LOG_INFO("f32_copy ");
f32_in_buf.resize(n_elements);
memcpy((void *) f32_in_buf.data(), in_tensor->data, n_elements * sizeof(float));
}
// do the calculation
//LLAMA_LOG_ERROR("calc ");
};
// load data and dequantize
if (ins.method == LLAMA_MERGE_LINEAR || ins.method == LLAMA_MERGE_SLERP) {
t0 = mls[0]->get_tensor_meta(ins.srcs[0]);
t1 = mls[1]->get_tensor_meta(ins.srcs[1]);
read_tensor_data(t0, *mls[0], in_buf0);
read_tensor_data(t1, *mls[1], in_buf1);
dequantize(t0, f32_in_buf0);
dequantize(t1, f32_in_buf1);
}
if (ins.method == LLAMA_MERGE_LINEAR) {
LLAMA_LOG_INFO("linear ");
float * in0 = (float *) f32_in_buf0.data();
float * in1 = (float *) f32_in_buf1.data();
float * dest = (float *) f32_out_buf.data();
for (size_t i = 0; i < n_elements; i++) {
float * in = (float *) f32_in_buf.data();
float * dest = (float *) f32_out_buf.data();
dest[i] += in[i] * scale;
dest[i] = in0[i] * ins.scales[0] + in1[i] * ins.scales[1];
}
}
if (ins.method == LLAMA_MERGE_SLERP) {
LLAMA_LOG_INFO("slerp ");
float * in0 = (float *) f32_in_buf0.data();
float * in1 = (float *) f32_in_buf1.data();
float * dest = (float *) f32_out_buf.data();
for (size_t i = 0; i < n_elements; i++) {
//dest[i] = in0[i] * ins.t + in1[i] * 0;
dest[i] = in0[i]; // not real slerp yet: just takes model 0 (enough for a self-merge test)
}
}
// re-quantize it
//LLAMA_LOG_ERROR("requant\n");
{
LLAMA_LOG_INFO("requant\n");
std::array<int64_t, 1 << 4> hist_cur = {};
const int n_per_row = out_tensor->ne[0];
const int n_rows = n_elements / n_per_row;
@@ -11588,16 +11597,10 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
GGML_ASSERT(new_size == out_buf.size());
}
// write tensor to file
{
LLAMA_LOG_ERROR("===> INPUT [layer %d] %f %f %f\n", i_layer_out, f32_in_buf[0].value, f32_in_buf[1].value, f32_in_buf[2].value);
LLAMA_LOG_ERROR("===> OUTPUT [layer %d] %f %f %f\n", i_layer_out, f32_out_buf[0], f32_out_buf[1], f32_out_buf[2]);
// my turn, write the result!
// write tensor data + padding
fout.write((const char *) out_buf.data(), out_buf.size());
zeros(fout, GGML_PAD(out_buf.size(), GGUF_DEFAULT_ALIGNMENT) - out_buf.size());
log_step(out_tensor);
}
LLAMA_LOG_INFO("===> INPUT %f %f %f\n", f32_in_buf0[0].value, f32_in_buf0[1].value, f32_in_buf0[2].value);
LLAMA_LOG_INFO("===> OUTPUT %f %f %f\n", f32_out_buf[0], f32_out_buf[1], f32_out_buf[2]);
write_output_tensor(out_tensor, out_buf.data());
}
// go back to the beginning of the file and write the updated metadata
@@ -11610,7 +11613,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
}
clean_up();
*/
return 0;
}

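The write path in the hunks above follows the usual GGUF layout: reserve room for the metadata header, stream each tensor padded to GGUF_DEFAULT_ALIGNMENT, then seek back to the start and write the finalized header. A condensed sketch of that pattern, assuming ctx_out already has every output tensor registered via gguf_add_tensor and that the tensor data is streamed in the same order:

#include <cstddef>
#include <fstream>
#include <utility>
#include <vector>
#include "ggml.h" // gguf_get_meta_size(), gguf_get_meta_data(), GGML_PAD, GGUF_DEFAULT_ALIGNMENT

// tensors must be streamed in the same order they were registered with gguf_add_tensor(),
// because that is the order in which their offsets were computed
static void write_gguf_file(struct gguf_context * ctx_out,
                            const std::vector<std::pair<const void *, size_t>> & tensors, // (data, nbytes)
                            const char * path) {
    std::ofstream fout(path, std::ios::binary);
    const size_t meta_size = gguf_get_meta_size(ctx_out);
    // placeholder for the metadata header, overwritten at the end
    for (size_t i = 0; i < meta_size; i++) fout.put(0);
    for (const auto & t : tensors) {
        fout.write((const char *) t.first, t.second);
        const size_t pad = GGML_PAD(t.second, GGUF_DEFAULT_ALIGNMENT) - t.second;
        for (size_t i = 0; i < pad; i++) fout.put(0); // zero padding, like zeros()
    }
    // go back and write the finalized header
    std::vector<char> meta(meta_size);
    gguf_get_meta_data(ctx_out, meta.data());
    fout.seekp(0);
    fout.write(meta.data(), meta.size());
}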

@@ -330,7 +330,7 @@ extern "C" {
enum llama_merge_method {
LLAMA_MERGE_LINEAR,
LLAMA_MERGE_SLERP,
LLAMA_MERGE_REPEAT,
LLAMA_MERGE_REPEAT, // doesn't work for now
LLAMA_MERGE_COPY,
};
@@ -339,7 +339,7 @@ extern "C" {
char name[GGML_MAX_NAME]; // name of output tensor
enum llama_merge_method method;
// we only support 2 models for now
char srcs[2][GGML_MAX_NAME]; // name of input tensors
char srcs[2][GGML_MAX_NAME]; // name of input tensors. if method == copy, only one src is non-empty
float scales[2]; // for linear method
float t; // for slerp method
};
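
To tie the header back to the config format, here is a short sketch of the instructions that lines such as "attn_output slerp 0,0,0.5" and "all copy 1,4" would roughly expand to for a single tensor. The "blk.<layer>.attn_output.weight" names follow the usual llama.cpp GGUF naming; the exact unit list that parse_config expands "all" into is not shown in this diff, so take the names as illustrative.

#include <cstring>
#include "llama.h"

static void example_insts(struct llama_merge_inst & slerp_ins, struct llama_merge_inst & copy_ins) {
    // "output layer 0" + "attn_output slerp 0,0,0.5"
    slerp_ins.method = LLAMA_MERGE_SLERP;
    strcpy(slerp_ins.name,    "blk.0.attn_output.weight"); // tensor name in the output model
    strcpy(slerp_ins.srcs[0], "blk.0.attn_output.weight"); // layer 0 of model 0
    strcpy(slerp_ins.srcs[1], "blk.0.attn_output.weight"); // layer 0 of model 1
    slerp_ins.t = 0.5f;

    // "output layer 4" + "all copy 1,4": take this tensor from model 1 only
    copy_ins.method = LLAMA_MERGE_COPY;
    strcpy(copy_ins.name,    "blk.4.attn_output.weight");
    strcpy(copy_ins.srcs[0], "");                          // empty source = not taken from model 0
    strcpy(copy_ins.srcs[1], "blk.4.attn_output.weight");
}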