convert.py : shorten and simplify permute

* idea from @KerfuffleV2
This commit is contained in:
Maximilian Markewitz 2023-07-27 20:59:43 +02:00
parent 01d16e1a1e
commit 9442c34f49

View file

@ -323,14 +323,11 @@ Vocab = Union[SentencePieceVocab, GGMLVocab]
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
if n_kv_head is None or n_head == n_kv_head: if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2) .swapaxes(1, 2)
.reshape(weights.shape)) .reshape(weights.shape))
else:
return (weights.reshape(n_head // n_kv_head, 2, weights.shape[0] * n_kv_head // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray: def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray: