convert-new.py : add map for skipping tensor serialization

Georgi Gerganov 2023-08-17 15:40:39 +03:00
parent 86bc9d2750
commit 7eaa315631
2 changed files with 65 additions and 37 deletions

convert-new.py

@@ -34,6 +34,13 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
 
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 
+ARCH=gguf.MODEL_ARCH.LLAMA
+NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
+
+#
+# data types
+#
+
 @dataclass(frozen=True)
 class UnquantizedDataType:
     name: str
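With ARCH and NAMES defined once at module level, the later hunks can drop the repeated gguf.MODEL_ARCH.LLAMA literals. A minimal sketch of how the NAMES table is used downstream (the exact name strings come from gguf.py and are not assumed here):

    import gguf

    ARCH  = gguf.MODEL_ARCH.LLAMA
    NAMES = gguf.MODEL_TENSOR_NAMES[ARCH]

    # NAMES maps MODEL_TENSOR members to format strings with a {bid} block index,
    # e.g. the canonical name of block 0's attention query weight:
    wq_name = NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"
    print(wq_name)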
@@ -55,6 +62,13 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
 NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
         {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
 
+SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
+    'BF16': DT_BF16,
+    'F16': DT_F16,
+    'F32': DT_F32,
+    'I32': DT_I32,
+}
+
 class GGMLFileType(enum.Enum):
     AllF32    = 0
     MostlyF16 = 1  # except 1d tensors
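The keys are the dtype strings that safetensors stores in its JSON header; the map is moved up here so it sits with the other dtype tables, and lazy_load_safetensors_file (further down) uses it to translate each tensor entry. A rough sketch of that translation, with a hypothetical header entry for illustration:

    # Hypothetical entry as it would appear in a safetensors JSON header.
    entry = {"dtype": "F16", "shape": [4096, 4096], "data_offsets": [0, 33554432]}

    data_type   = SAFETENSORS_DATA_TYPES[entry["dtype"]]  # DT_F16
    numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]           # numpy float16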
@@ -70,14 +84,10 @@ class GGMLFileType(enum.Enum):
         else:
             raise ValueError(self)
 
-def find_n_mult(n_ff: int, n_embd: int) -> int:
-    # hardcoded magic range
-    for n_mult in range(8192, 1, -1):
-        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
-        if calc_ff == n_ff:
-            return n_mult
-    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
+#
+# hparams loading
+#
 
 @dataclass
 class Params:
@@ -91,6 +101,15 @@ class Params:
     n_head_kv:  int
     f_norm_eps: float
 
+    @staticmethod
+    def find_n_mult(n_ff: int, n_embd: int) -> int:
+        # hardcoded magic range
+        for n_mult in range(8192, 1, -1):
+            calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
+            if calc_ff == n_ff:
+                return n_mult
+        raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
+
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
         # try transformer naming first
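Moving find_n_mult into Params is purely organizational; the search itself is unchanged. A quick worked check, assuming the well-known LLaMA-7B shapes (n_embd = 4096, n_ff = 11008): 8*4096 // 3 = 10922, and the largest n_mult at or below the hardcoded 8192 whose round-up reproduces 11008 is 5504, which the descending loop hits first.

    # Assumed example values (LLaMA-7B shapes), reproducing what Params.find_n_mult() does.
    n_embd, n_ff = 4096, 11008
    matches = [m for m in range(8192, 1, -1)
               if (((8*n_embd) // 3 + m - 1) // m) * m == n_ff]
    print(matches[0])  # 5504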
@@ -139,7 +158,7 @@ class Params:
         n_head_kv  = config["num_key_value_heads"];
         f_norm_eps = config["rms_norm_eps"];
 
-        n_mult = find_n_mult(n_ff, n_embd);
+        n_mult = Params.find_n_mult(n_ff, n_embd);
 
         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
@@ -210,6 +229,10 @@ class Params:
 
         return params
 
+#
+# vocab
+#
+
 class BpeVocab:
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
@@ -294,10 +317,14 @@ class SentencePieceVocab:
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 Vocab = Union[BpeVocab, SentencePieceVocab]
 
+#
+# data loading
+# TODO: reuse (probably move to gguf.py?)
+#
+
 def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
     if n_head_kv is not None and n_head != n_head_kv:
         n_head //= n_head_kv
@@ -593,14 +620,6 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
     return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
 
-SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
-    'BF16': DT_BF16,
-    'F16': DT_F16,
-    'F32': DT_F32,
-    'I32': DT_I32,
-}
-
 def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     header_size, = struct.unpack('<Q', fp.read(8))
     header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
@@ -690,17 +709,16 @@ class OutputFile:
         self.gguf = gguf.GGUFWriter.open(fname_out)
 
     def add_meta_arch(self, params: Params) -> None:
-        llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
-
-        self.gguf.add_architecture        (llm_arch)
-        self.gguf.add_context_length      (llm_arch, params.n_ctx)
-        self.gguf.add_embedding_length    (llm_arch, params.n_embd)
-        self.gguf.add_block_count         (llm_arch, params.n_layer)
-        self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
-        self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
-        self.gguf.add_head_count          (llm_arch, params.n_head)
-        self.gguf.add_head_count_kv       (llm_arch, params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (llm_arch, params.f_norm_eps)
+        arch = gguf.MODEL_ARCH_NAMES[ARCH]
+        self.gguf.add_architecture        (arch)
+        self.gguf.add_context_length      (arch, params.n_ctx)
+        self.gguf.add_embedding_length    (arch, params.n_embd)
+        self.gguf.add_block_count         (arch, params.n_layer)
+        self.gguf.add_feed_forward_length (arch, params.n_ff)
+        self.gguf.add_rope_dimension_count(arch, params.n_embd // params.n_head)
+        self.gguf.add_head_count          (arch, params.n_head)
+        self.gguf.add_head_count_kv       (arch, params.n_head_kv)
+        self.gguf.add_layer_norm_rms_eps  (arch, params.f_norm_eps)
 
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
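Only one derived value is written in add_meta_arch: the RoPE dimension count, which is simply the per-head embedding width. For the assumed LLaMA-7B shapes:

    n_embd, n_head = 4096, 32    # assumed LLaMA-7B shapes
    print(n_embd // n_head)      # 128, the value passed to add_rope_dimension_count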
@@ -754,7 +772,7 @@ class OutputFile:
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
-    wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
+    wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
 
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
@@ -767,7 +785,7 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
 
 def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
-    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
+    tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
 
     out: LazyModel = {}
     for name, lazy_tensor in model.items():
@@ -782,9 +800,11 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
         else:
             raise Exception(f"Unexpected tensor name: {name}")
 
-        out[name_new] = lazy_tensor
-
-        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+        if gguf.should_skip_tensor(ARCH, params.n_layer, name_new):
+            print(f"skipping tensor {name_new}")
+        else:
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            out[name_new] = lazy_tensor
 
     return out
@@ -961,7 +981,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir, args.vocabtype)
 
     model = model_plus.model
-    model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
+    model = convert_model_names(model, params)
     output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, output_type)

gguf.py

@@ -159,11 +159,19 @@ MODEL_TENSOR_NAMES = {
 
 # tensors that will not be serialized
 MODEL_TENSOR_SKIP = {
-    MODEL_ARCH.LLAMA : {
+    MODEL_ARCH.LLAMA : [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
-    },
+    ],
 }
 
+def should_skip_tensor(arch : MODEL_ARCH, n_blocks : int, name : str) -> bool:
+    for skip in MODEL_TENSOR_SKIP.get(arch, []):
+        for i in range(n_blocks):
+            if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
+                return True
+
+    return False
+
 def get_tensor_name_map(arch : MODEL_ARCH, n_blocks : int) -> dict:
     tensor_map = {}