shuffle attn_q.weight and attn_output.weight for broadcasting
This commit is contained in:
parent
9339ffc96d
commit
db1b18dc97
2 changed files with 26 additions and 0 deletions
|
@ -1002,6 +1002,20 @@ class PlamoModel(Model):
|
||||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||||
|
|
||||||
|
def shuffle_attn_q_weight(self, data_torch, n_groups=8, n_head_kv=5, head_dim=128):
    """Reorder the rows of attn_q.weight for GQA broadcasting in ggml_mul_mat.

    The row blocks are stored grouped as (n_groups, n_head_kv, head_dim); ggml's
    grouped-query attention expects them interleaved as (n_head_kv, n_groups,
    head_dim), so the first two axes are swapped before flattening back.

    Defaults (8, 5, 128 -> hidden size 5120) match PLaMo-13B, preserving the
    original hard-coded behavior; pass other values to reuse this for models
    with a different head layout.

    Args:
        data_torch: square weight tensor of shape (hidden, hidden) where
            hidden == n_groups * n_head_kv * head_dim.
        n_groups: number of query groups per KV head.
        n_head_kv: number of key/value heads.
        head_dim: per-head embedding dimension.

    Returns:
        The row-shuffled (hidden, hidden) tensor.
    """
    hidden = n_groups * n_head_kv * head_dim
    assert data_torch.size() == (hidden, hidden)
    # split rows into (group, kv_head, head_dim), swap group <-> kv_head,
    # then flatten back to a square matrix
    data_torch = data_torch.reshape(n_groups, n_head_kv, head_dim, hidden)
    data_torch = torch.permute(data_torch, (1, 0, 2, 3))
    data_torch = torch.reshape(data_torch, (hidden, hidden))
    return data_torch
|
||||||
|
|
||||||
|
def shuffle_attn_output_weight(self, data_torch, n_groups=8, n_head_kv=5, head_dim=128):
    """Reorder the columns of attn_output.weight for GQA broadcasting.

    Inverse counterpart of shuffle_attn_q_weight: since the query rows are
    interleaved from (n_groups, n_head_kv, ...) to (n_head_kv, n_groups, ...),
    the output projection's columns must be permuted the same way so the
    matmul result is unchanged.

    Defaults (8, 5, 128 -> hidden size 5120) match PLaMo-13B, preserving the
    original hard-coded behavior; pass other values to reuse this for models
    with a different head layout.

    Args:
        data_torch: square weight tensor of shape (hidden, hidden) where
            hidden == n_groups * n_head_kv * head_dim.
        n_groups: number of query groups per KV head.
        n_head_kv: number of key/value heads.
        head_dim: per-head embedding dimension.

    Returns:
        The column-shuffled (hidden, hidden) tensor.
    """
    hidden = n_groups * n_head_kv * head_dim
    assert data_torch.size() == (hidden, hidden)
    # split columns into (group, kv_head, head_dim), swap group <-> kv_head,
    # then flatten back to a square matrix
    data_torch = data_torch.reshape(hidden, n_groups, n_head_kv, head_dim)
    data_torch = torch.permute(data_torch, (0, 2, 1, 3))
    data_torch = torch.reshape(data_torch, (hidden, hidden))
    return data_torch
|
||||||
|
|
||||||
def write_tensors(self):
|
def write_tensors(self):
|
||||||
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
|
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
|
||||||
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||||
|
@ -1016,6 +1030,12 @@ class PlamoModel(Model):
|
||||||
print(f"Can not map tensor {name!r}")
|
print(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
# shuffle for broadcasting of gqa in ggml_mul_mat
|
||||||
|
if new_name.endswith("attn_q.weight"):
|
||||||
|
data_torch = self.shuffle_attn_q_weight(data_torch)
|
||||||
|
elif new_name.endswith("attn_output.weight"):
|
||||||
|
data_torch = self.shuffle_attn_output_weight(data_torch)
|
||||||
|
|
||||||
old_dtype = data_torch.dtype
|
old_dtype = data_torch.dtype
|
||||||
|
|
||||||
# convert any unsupported data types to float32
|
# convert any unsupported data types to float32
|
||||||
|
|
|
@ -5601,11 +5601,14 @@ struct llm_build_context {
|
||||||
0);
|
0);
|
||||||
cb(k, "k", il);
|
cb(k, "k", il);
|
||||||
|
|
||||||
|
/*
|
||||||
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
|
// we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
|
||||||
struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]);
|
struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]);
|
||||||
cb(k_repeated, "k_repeated", il);
|
cb(k_repeated, "k_repeated", il);
|
||||||
|
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q);
|
||||||
|
*/
|
||||||
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
|
kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
|
||||||
|
@ -5620,11 +5623,14 @@ struct llm_build_context {
|
||||||
0);
|
0);
|
||||||
cb(v, "v", il);
|
cb(v, "v", il);
|
||||||
|
|
||||||
|
/*
|
||||||
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
|
// we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
|
||||||
struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]);
|
struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]);
|
||||||
cb(k_repeated, "v_repeated", il);
|
cb(k_repeated, "v_repeated", il);
|
||||||
|
|
||||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq);
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq);
|
||||||
|
*/
|
||||||
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
||||||
cb(kqv, "kqv", il);
|
cb(kqv, "kqv", il);
|
||||||
|
|
||||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue