From db1b18dc9786f9fb17264878acbcd69ab35545ce Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 24 Dec 2023 17:58:55 +0900 Subject: [PATCH] shuffle attn_q.weight and attn_output.weight for broadcasting --- convert-hf-to-gguf.py | 20 ++++++++++++++++++++ llama.cpp | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6c783d796..689285fd6 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1002,6 +1002,20 @@ class PlamoModel(Model): self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + def shuffle_attn_q_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(8, 5, 128, 5120) + data_torch = torch.permute(data_torch, (1, 0, 2, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def shuffle_attn_output_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(5120, 8, 5, 128) + data_torch = torch.permute(data_torch, (0, 2, 1, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + def write_tensors(self): block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) @@ -1016,6 +1030,12 @@ class PlamoModel(Model): print(f"Can not map tensor {name!r}") sys.exit() + # shuffle for broadcasting of gqa in ggml_mul_mat + if new_name.endswith("attn_q.weight"): + data_torch = self.shuffle_attn_q_weight(data_torch) + elif new_name.endswith("attn_output.weight"): + data_torch = self.shuffle_attn_output_weight(data_torch) + old_dtype = data_torch.dtype # convert any unsupported data types to float32 diff --git a/llama.cpp b/llama.cpp index 25a1a4a7c..3579c8960 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5601,11 +5601,14 @@ struct llm_build_context { 0); cb(k, "k", il); + /* // we should 
avoid repeating K but current ggml_mul_mat generates wrong values for grouped query attention struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); cb(k_repeated, "k_repeated", il); struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); + */ + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); @@ -5620,11 +5623,14 @@ struct llm_build_context { 0); cb(v, "v", il); + /* // we should avoid repeating V but current ggml_mul_mat generates wrong values for grouped query attention struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); cb(k_repeated, "v_repeated", il); struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); + */ + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);