From 9442c34f4960586f745245f3dd2782b3e4fb8129 Mon Sep 17 00:00:00 2001
From: Maximilian Markewitz <77107165+mj-shifu@users.noreply.github.com>
Date: Thu, 27 Jul 2023 20:59:43 +0200
Subject: [PATCH] convert.py : shorten and simplify permute

* idea from @KerfuffleV2
---
 convert.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/convert.py b/convert.py
index 548bd9d3b..ab6a4e10e 100644
--- a/convert.py
+++ b/convert.py
@@ -323,14 +323,11 @@ Vocab = Union[SentencePieceVocab, GGMLVocab]
 
 
 def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
-    if n_kv_head is None or n_head == n_kv_head:
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                    .swapaxes(1, 2)
-                    .reshape(weights.shape))
-    else:
-        return (weights.reshape(n_head // n_kv_head, 2, weights.shape[0] * n_kv_head // n_head // 2, *weights.shape[1:])
-                    .swapaxes(1, 2)
-                    .reshape(weights.shape))
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
+    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
 
 
 def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray: