implement slerp

ngxson 2024-03-03 18:58:42 +01:00
parent a032bb6ca2
commit 10c477b8a8


@@ -11389,12 +11389,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         llm_load_arch(*ml, *model);
         llm_load_hparams(*ml, *model);
-        if (i > 0 && models[i-1]->hparams != model->hparams) {
-            LLAMA_LOG_ERROR("hparams of input models are different, aborting...");
-            clean_up();
-            return -1;
-        }
         models.push_back(std::move(model));
         mls.push_back(std::move(ml));
     }
@@ -11521,6 +11515,8 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         std::vector<no_init<float>> f32_in_buf1; // dequant it internally
         std::vector<float> f32_out_buf(n_elements, 0.0); // do not resize!
         std::vector<uint8_t> out_buf(ggml_nbytes(out_tensor)); // do not resize!
+        const int n_per_row = out_tensor->ne[0];
+        const int n_rows    = n_elements / n_per_row;

         if (ins.method == LLAMA_MERGE_COPY) {
             LLAMA_LOG_INFO("copy\n");
@@ -11565,13 +11561,49 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         }

         if (ins.method == LLAMA_MERGE_SLERP) {
+            // Python code: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
             LLAMA_LOG_INFO("slerp ");
-            float * in0  = (float *) f32_in_buf0.data();
-            float * in1  = (float *) f32_in_buf1.data();
-            float * dest = (float *) f32_out_buf.data();
-            for (size_t i = 0; i < n_elements; i++) {
-                //dest[i] = in0[i] * ins.t + in1[i] * 0;
-                dest[i] = in0[i];
-            }
+            static const float dot_threshold = 0.9995;
+            auto lerp_row = [](float * in0, float * in1, float * out, size_t nelem, float t) {
+                for (size_t i = 0; i < nelem; i++) {
+                    out[i] = in0[i] * (1.0 - t) + in1[i] * t;
+                }
+            };
+            auto slerp_row = [&lerp_row](float * in0, float * in1, float * out, size_t nelem, float t) {
+                float norm0 = std::sqrt(std::inner_product(in0, in0 + nelem, in0, 0.0));
+                float norm1 = std::sqrt(std::inner_product(in1, in1 + nelem, in1, 0.0));
+                // Normalize the vectors to get the directions
+                std::vector<float> v0(nelem);
+                std::vector<float> v1(nelem);
+                for (size_t i = 0; i < nelem; i++) {
+                    v0[i] = in0[i] / norm0;
+                    v1[i] = in1[i] / norm1;
+                }
+                // Dot product of the normalized vectors
+                float dot = std::inner_product(v0.begin(), v0.end(), v1.begin(), 0.0);
+                // If the absolute value of the dot product is almost 1, the vectors are ~collinear, so use lerp
+                if (std::abs(dot) > dot_threshold) {
+                    return lerp_row(in0, in1, out, nelem, t);
+                }
+                // Calculate the initial angle between v0 and v1
+                float theta_0     = std::acos(dot);
+                float sin_theta_0 = std::sin(theta_0);
+                // Angle at timestep t
+                float theta_t     = theta_0 * t;
+                float sin_theta_t = std::sin(theta_t);
+                // Finish the slerp algorithm
+                float s0 = std::sin(theta_0 - theta_t) / sin_theta_0;
+                float s1 = sin_theta_t / sin_theta_0;
+                for (size_t i = 0; i < nelem; i++) {
+                    out[i] = in0[i] * s0 + in1[i] * s1;
+                }
+            };
+            for (int r = 0; r < n_rows; r++) {
+                float * in0  = (float *) f32_in_buf0.data();
+                float * in1  = (float *) f32_in_buf1.data();
+                float * dest = (float *) f32_out_buf.data();
+                size_t offset = n_per_row * r;
+                slerp_row(in0 + offset, in1 + offset, dest + offset, n_per_row, ins.t);
+            }
         }
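For reference, the per-row slerp can be exercised outside llama.cpp. Below is a minimal standalone sketch: the two functions mirror the lambdas in the diff, while the harness in main() (names and inputs) is hypothetical and only checks one easy case. It also assumes non-zero rows, since slerp divides by the row norms:

    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    static const float dot_threshold = 0.9995f;

    // plain linear interpolation, used when the rows are nearly collinear
    static void lerp_row(const float * in0, const float * in1, float * out, size_t nelem, float t) {
        for (size_t i = 0; i < nelem; i++) {
            out[i] = in0[i] * (1.0f - t) + in1[i] * t;
        }
    }

    // spherical linear interpolation between two rows (assumes non-zero norms)
    static void slerp_row(const float * in0, const float * in1, float * out, size_t nelem, float t) {
        float norm0 = std::sqrt(std::inner_product(in0, in0 + nelem, in0, 0.0));
        float norm1 = std::sqrt(std::inner_product(in1, in1 + nelem, in1, 0.0));
        std::vector<float> v0(nelem), v1(nelem);
        for (size_t i = 0; i < nelem; i++) {
            v0[i] = in0[i] / norm0;
            v1[i] = in1[i] / norm1;
        }
        float dot = std::inner_product(v0.begin(), v0.end(), v1.begin(), 0.0);
        if (std::abs(dot) > dot_threshold) {
            // fall back to lerp; also avoids dividing by sin(theta_0) ~ 0 below
            return lerp_row(in0, in1, out, nelem, t);
        }
        float theta_0     = std::acos(dot);      // angle between the rows
        float sin_theta_0 = std::sin(theta_0);
        float theta_t     = theta_0 * t;         // angle at timestep t
        float sin_theta_t = std::sin(theta_t);
        float s0 = std::sin(theta_0 - theta_t) / sin_theta_0;
        float s1 = sin_theta_t / sin_theta_0;
        for (size_t i = 0; i < nelem; i++) {
            out[i] = in0[i] * s0 + in1[i] * s1;
        }
    }

    int main() {
        float a[4] = {1.0f, 0.0f, 0.0f, 0.0f};
        float b[4] = {0.0f, 1.0f, 0.0f, 0.0f};
        float out[4];
        slerp_row(a, b, out, 4, 0.5f);
        // orthogonal unit rows at t = 0.5: expect ~0.7071 in the first two slots
        printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
    }

With two orthogonal unit rows and t = 0.5, both scale factors come out to sin(pi/4)/sin(pi/2) ~ 0.7071, the midpoint on the unit arc. The dot_threshold fallback is what keeps the method stable: as the rows approach collinearity, theta_0 approaches 0 and the s0/s1 divisions would blow up, while lerp converges to the same answer.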
@@ -11579,8 +11611,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         {
             LLAMA_LOG_INFO("requant\n");
             std::array<int64_t, 1 << 4> hist_cur = {};
-            const int n_per_row = out_tensor->ne[0];
-            const int n_rows = n_elements / n_per_row;
             static const int min_chunk_size = 32 * 512;
             const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
             size_t new_size = llama_tensor_quantize_internal(
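The chunk_size expression picks the smallest multiple of n_per_row that reaches min_chunk_size (or a single row when the row is already big enough), so each quantization chunk covers whole rows. A quick check of the rounding with made-up row lengths (the helper name is hypothetical; the expression is the one from the diff):

    #include <cstdio>

    // same rounding expression as the diff, factored into a helper for testing
    static int calc_chunk_size(int n_per_row) {
        static const int min_chunk_size = 32 * 512; // 16384
        return n_per_row >= min_chunk_size
            ? n_per_row
            : n_per_row * ((min_chunk_size + n_per_row - 1) / n_per_row);
    }

    int main() {
        printf("%d\n", calc_chunk_size(4096));  // 4 rows -> 16384
        printf("%d\n", calc_chunk_size(5000));  // 4 rows -> 20000
        printf("%d\n", calc_chunk_size(20000)); // 1 row  -> 20000 (already >= 16384)
    }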