Reuse definitions from convert.py
parent 63da54e016
commit 0a6d5ad7cc

1 changed file with 7 additions and 39 deletions
@@ -3,45 +3,11 @@ import os
 import re
 import struct
 import sys
-from dataclasses import dataclass
-from typing import Any, Sequence
+from typing import Any, Dict, Sequence, TextIO
 
-import numpy as np
 import torch
 
-# TODO: import this from convert.py once #545 is merged
-@dataclass(frozen=True)
-class UnquantizedDataType:
-    name: str
-
-
-DT_F16 = UnquantizedDataType("F16")
-DT_F32 = UnquantizedDataType("F32")
-
-
-@dataclass(frozen=True)
-class QuantizedDataType:
-    groupsize: int
-    have_addends: bool
-    have_g_idx: bool
-
-
-DataType = UnquantizedDataType
-
-DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
-    DT_F32: 0,
-    DT_F16: 1,
-}
-
-DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
-    DT_F16: np.dtype(np.float16),
-    DT_F32: np.dtype(np.float32),
-}
-
-NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {
-    dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()
-}
-
+from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType
 
 HF_SUBLAYER_TO_GGML = {
     "self_attn.q_proj": "attention.wq",
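For reference, a minimal sketch of how the imported mappings compose once they come from convert.py: a tensor's numpy dtype can be mapped to a DataType and from there to its integer ftype code (the values mirror the tables deleted above). The helper name ftype_for and the standalone scaffolding are illustrative, not part of the commit; the sketch assumes convert.py is importable, as the new import above already requires.

    import numpy as np

    from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE

    def ftype_for(arr: np.ndarray) -> int:
        # np.float32 -> DT_F32 -> 0, np.float16 -> DT_F16 -> 1 (per the removed tables)
        return DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[arr.dtype]]

    assert ftype_for(np.zeros(1, dtype=np.float32)) == 0
    assert ftype_for(np.zeros(1, dtype=np.float16)) == 1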
@@ -59,7 +25,7 @@ HF_SUBLAYER_TO_GGML = {
 }
 
 
-def translate_tensor_name(t):
+def translate_tensor_name(t: str) -> str:
     match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
     if match:
         nn = match.group(1)
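As a quick illustrative check of the regex above: the example tensor key below is a typical PEFT-style LoRA name (an assumption, not taken from this commit), and the captured groups are the layer index, the HF sublayer name that HF_SUBLAYER_TO_GGML translates (here to "attention.wq"), and the A/B side of the LoRA pair.

    import re

    pattern = r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight"
    name = "base_model.model.layers.0.self_attn.q_proj.lora_A.weight"  # hypothetical key

    match = re.match(pattern, name)
    assert match is not None
    print(match.groups())  # ('0', 'self_attn.q_proj', 'A')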
@@ -80,13 +46,15 @@
         sys.exit(1)
 
 
-def write_file_header(fout, params):
+def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
     fout.write(b"ggla"[::-1])  # magic (ggml lora)
     fout.write(struct.pack("i", 1))  # file version
     fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
 
 
-def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
+def write_tensor_header(
+    self, name: str, shape: Sequence[int], data_type: DataType
+) -> None:
     sname = name.encode("utf-8")
     fout.write(
         struct.pack(
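To make the binary layout concrete, here is a minimal sketch of reading back what write_file_header emits: a 4-byte magic (b"ggla" reversed), an int32 file version, then int32 r and int32 lora_alpha. The reader function and file handling are illustrative only and assume the same native byte order that struct.pack uses on the writing side.

    import struct

    def read_lora_header(path: str) -> dict:
        # Hypothetical reader mirroring write_file_header above.
        with open(path, "rb") as fin:
            magic = fin.read(4)
            assert magic == b"ggla"[::-1], "not a ggml lora file"
            (version,) = struct.unpack("i", fin.read(4))
            r, lora_alpha = struct.unpack("ii", fin.read(8))
        return {"version": version, "r": r, "lora_alpha": lora_alpha}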