convert.py : shorten and simplify permute
* idea from @KerfuffleV2
This commit is contained in:
parent
01d16e1a1e
commit
9442c34f49
1 changed file with 5 additions and 8 deletions
|
@@ -323,14 +323,11 @@ Vocab = Union[SentencePieceVocab, GGMLVocab]
|
||||||
|
|
||||||
|
|
||||||
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
||||||
if n_kv_head is None or n_head == n_kv_head:
|
if n_kv_head is not None and n_head != n_kv_head:
|
||||||
|
n_head //= n_kv_head
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
.swapaxes(1, 2)
|
.swapaxes(1, 2)
|
||||||
.reshape(weights.shape))
|
.reshape(weights.shape))
|
||||||
else:
|
|
||||||
return (weights.reshape(n_head // n_kv_head, 2, weights.shape[0] * n_kv_head // n_head // 2, *weights.shape[1:])
|
|
||||||
.swapaxes(1, 2)
|
|
||||||
.reshape(weights.shape))
|
|
||||||
|
|
||||||
|
|
||||||
def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
|
def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue