convert-new.py : add map for skipping tensor serialization
This commit is contained in:
parent
86bc9d2750
commit
7eaa315631
2 changed files with 65 additions and 37 deletions
|
@ -34,6 +34,13 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
|
|||
|
||||
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
||||
|
||||
ARCH=gguf.MODEL_ARCH.LLAMA
|
||||
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
|
||||
|
||||
#
|
||||
# data types
|
||||
#
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnquantizedDataType:
|
||||
name: str
|
||||
|
@ -55,6 +62,13 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
|
|||
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
|
||||
{dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
|
||||
|
||||
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
||||
'BF16': DT_BF16,
|
||||
'F16': DT_F16,
|
||||
'F32': DT_F32,
|
||||
'I32': DT_I32,
|
||||
}
|
||||
|
||||
class GGMLFileType(enum.Enum):
|
||||
AllF32 = 0
|
||||
MostlyF16 = 1 # except 1d tensors
|
||||
|
@ -70,14 +84,10 @@ class GGMLFileType(enum.Enum):
|
|||
else:
|
||||
raise ValueError(self)
|
||||
|
||||
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||
# hardcoded magic range
|
||||
for n_mult in range(8192, 1, -1):
|
||||
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
|
||||
if calc_ff == n_ff:
|
||||
return n_mult
|
||||
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
|
||||
|
||||
#
|
||||
# hparams loading
|
||||
#
|
||||
|
||||
@dataclass
|
||||
class Params:
|
||||
|
@ -91,6 +101,15 @@ class Params:
|
|||
n_head_kv: int
|
||||
f_norm_eps: float
|
||||
|
||||
@staticmethod
|
||||
def find_n_mult(n_ff: int, n_embd: int) -> int:
|
||||
# hardcoded magic range
|
||||
for n_mult in range(8192, 1, -1):
|
||||
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
|
||||
if calc_ff == n_ff:
|
||||
return n_mult
|
||||
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
|
||||
|
||||
@staticmethod
|
||||
def guessed(model: 'LazyModel') -> 'Params':
|
||||
# try transformer naming first
|
||||
|
@ -139,7 +158,7 @@ class Params:
|
|||
n_head_kv = config["num_key_value_heads"];
|
||||
f_norm_eps = config["rms_norm_eps"];
|
||||
|
||||
n_mult = find_n_mult(n_ff, n_embd);
|
||||
n_mult = Params.find_n_mult(n_ff, n_embd);
|
||||
|
||||
if "max_sequence_length" in config:
|
||||
n_ctx = config["max_sequence_length"]
|
||||
|
@ -210,6 +229,10 @@ class Params:
|
|||
return params
|
||||
|
||||
|
||||
#
|
||||
# vocab
|
||||
#
|
||||
|
||||
class BpeVocab:
|
||||
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
|
||||
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
|
||||
|
@ -294,10 +317,14 @@ class SentencePieceVocab:
|
|||
def __repr__(self) -> str:
|
||||
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
|
||||
|
||||
|
||||
Vocab = Union[BpeVocab, SentencePieceVocab]
|
||||
|
||||
|
||||
#
|
||||
# data loading
|
||||
# TODO: reuse (probably move to gguf.py?)
|
||||
#
|
||||
|
||||
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
||||
if n_head_kv is not None and n_head != n_head_kv:
|
||||
n_head //= n_head_kv
|
||||
|
@ -593,14 +620,6 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
|
|||
return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
|
||||
|
||||
|
||||
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
||||
'BF16': DT_BF16,
|
||||
'F16': DT_F16,
|
||||
'F32': DT_F32,
|
||||
'I32': DT_I32,
|
||||
}
|
||||
|
||||
|
||||
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
|
||||
header_size, = struct.unpack('<Q', fp.read(8))
|
||||
header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
|
||||
|
@ -690,17 +709,16 @@ class OutputFile:
|
|||
self.gguf = gguf.GGUFWriter.open(fname_out)
|
||||
|
||||
def add_meta_arch(self, params: Params) -> None:
|
||||
llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
|
||||
|
||||
self.gguf.add_architecture (llm_arch)
|
||||
self.gguf.add_context_length (llm_arch, params.n_ctx)
|
||||
self.gguf.add_embedding_length (llm_arch, params.n_embd)
|
||||
self.gguf.add_block_count (llm_arch, params.n_layer)
|
||||
self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
|
||||
self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
|
||||
self.gguf.add_head_count (llm_arch, params.n_head)
|
||||
self.gguf.add_head_count_kv (llm_arch, params.n_head_kv)
|
||||
self.gguf.add_layer_norm_rms_eps (llm_arch, params.f_norm_eps)
|
||||
arch = gguf.MODEL_ARCH_NAMES[ARCH]
|
||||
self.gguf.add_architecture (arch)
|
||||
self.gguf.add_context_length (arch, params.n_ctx)
|
||||
self.gguf.add_embedding_length (arch, params.n_embd)
|
||||
self.gguf.add_block_count (arch, params.n_layer)
|
||||
self.gguf.add_feed_forward_length (arch, params.n_ff)
|
||||
self.gguf.add_rope_dimension_count(arch, params.n_embd // params.n_head)
|
||||
self.gguf.add_head_count (arch, params.n_head)
|
||||
self.gguf.add_head_count_kv (arch, params.n_head_kv)
|
||||
self.gguf.add_layer_norm_rms_eps (arch, params.f_norm_eps)
|
||||
|
||||
def add_meta_vocab(self, vocab: Vocab) -> None:
|
||||
tokens = []
|
||||
|
@ -754,7 +772,7 @@ class OutputFile:
|
|||
|
||||
|
||||
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
|
||||
wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
|
||||
wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
|
||||
|
||||
if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
|
||||
return GGMLFileType.AllF32
|
||||
|
@ -767,7 +785,7 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
|
|||
|
||||
|
||||
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
||||
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
|
||||
tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
|
||||
|
||||
out: LazyModel = {}
|
||||
for name, lazy_tensor in model.items():
|
||||
|
@ -782,9 +800,11 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
|||
else:
|
||||
raise Exception(f"Unexpected tensor name: {name}")
|
||||
|
||||
out[name_new] = lazy_tensor
|
||||
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
|
||||
if gguf.should_skip_tensor(ARCH, params.n_layer, name_new):
|
||||
print(f"skipping tensor {name_new}")
|
||||
else:
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
|
||||
out[name_new] = lazy_tensor
|
||||
|
||||
return out
|
||||
|
||||
|
@ -961,7 +981,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
vocab = load_vocab(vocab_dir, args.vocabtype)
|
||||
|
||||
model = model_plus.model
|
||||
model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
|
||||
model = convert_model_names(model, params)
|
||||
output_type = pick_output_type(model, args.outtype)
|
||||
model = convert_to_output_type(model, output_type)
|
||||
outfile = args.outfile or default_outfile(model_plus.paths, output_type)
|
||||
|
|
14
gguf.py
14
gguf.py
|
@ -159,11 +159,19 @@ MODEL_TENSOR_NAMES = {
|
|||
|
||||
# tensors that will not be serialized
|
||||
MODEL_TENSOR_SKIP = {
|
||||
MODEL_ARCH.LLAMA : {
|
||||
MODEL_ARCH.LLAMA : [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
def should_skip_tensor(arch : MODEL_ARCH, n_blocks : int, name : str) -> bool:
|
||||
for skip in MODEL_TENSOR_SKIP.get(arch, []):
|
||||
for i in range(n_blocks):
|
||||
if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_tensor_name_map(arch : MODEL_ARCH, n_blocks : int) -> dict:
|
||||
tensor_map = {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue