shuffle attn_q.weight and attn_output.weight for broadcasting

This commit is contained in:
okada 2023-12-24 17:58:55 +09:00
parent 9339ffc96d
commit db1b18dc97
2 changed files with 26 additions and 0 deletions

View file

@ -1002,6 +1002,20 @@ class PlamoModel(Model):
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
def shuffle_attn_q_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(8, 5, 128, 5120)
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch
def shuffle_attn_output_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(5120, 8, 5, 128)
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch
def write_tensors(self): def write_tensors(self):
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
@ -1016,6 +1030,12 @@ class PlamoModel(Model):
print(f"Can not map tensor {name!r}") print(f"Can not map tensor {name!r}")
sys.exit() sys.exit()
# shuffle for broadcasting of gqa in ggml_mul_mat
if new_name.endswith("attn_q.weight"):
data_torch = self.shuffle_attn_q_weight(data_torch)
elif new_name.endswith("attn_output.weight"):
data_torch = self.shuffle_attn_output_weight(data_torch)
old_dtype = data_torch.dtype old_dtype = data_torch.dtype
# convert any unsupported data types to float32 # convert any unsupported data types to float32

View file

@ -5601,11 +5601,14 @@ struct llm_build_context {
0); 0);
cb(k, "k", il); cb(k, "k", il);
/*
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]);
cb(k_repeated, "k_repeated", il); cb(k_repeated, "k_repeated", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q);
*/
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
@ -5620,11 +5623,14 @@ struct llm_build_context {
0); 0);
cb(v, "v", il); cb(v, "v", il);
/*
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]);
cb(k_repeated, "v_repeated", il); cb(k_repeated, "v_repeated", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq);
*/
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
cb(kqv, "kqv", il); cb(kqv, "kqv", il);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);