Skip qkv reshaping for non-parallel attention

This commit is contained in:
akawrykow 2023-08-29 15:13:04 -07:00
parent e276e4b606
commit de64f091c8

View file

@ -206,6 +206,7 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# params for qkv transform
head_dim = hparams["hidden_size"] // n_head
parallel_attn = hparams["parallel_attn"]
# tensor info
print("gguf: get tensor metadata")
@ -240,7 +241,7 @@ for part_name in part_names:
# in contiguous fashion.
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
if "query_key_value" in name:
if "query_key_value" in name and parallel_attn:
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)