def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
    """Interleave the two halves of every attention head's rows in *weights*.

    Axis 0 of *weights* is viewed as ``(heads, 2, rows_per_head // 2)``; the
    two inner axes are swapped and the original shape is restored.  Within
    each head, row ``i`` of the first half therefore ends up directly before
    row ``i`` of the second half.  Trailing axes are left untouched.

    Args:
        weights: Array whose first axis stacks the rows of all heads.
        n_head: Number of attention (query) heads.
        n_kv_head: Number of key/value heads for grouped-query attention.
            ``None`` (or a value equal to *n_head*) selects plain multi-head
            behaviour, i.e. *n_head* row blocks.

    Returns:
        An array with the same shape as *weights*, rows reordered per head.

    Raises:
        ValueError: If ``weights.shape[0]`` cannot be split evenly into
            ``heads * 2`` groups of rows (raised by ``reshape``).
    """
    # When n_kv_head is given, the tensor being permuted is expected to hold
    # one block of rows per key/value head, so the block count is n_kv_head.
    # The earlier `n_head // n_kv_head` only produced that count when
    # n_head == n_kv_head ** 2 (e.g. 64 / 8) and mis-grouped rows otherwise
    # (e.g. 32 / 8).  A local is used so the caller-visible parameter is not
    # reassigned.
    heads = n_head if n_kv_head is None else n_kv_head
    half_rows = weights.shape[0] // heads // 2
    return (weights.reshape(heads, 2, half_rows, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))