convert-new.py : add gguf key-value pairs

Author: Georgi Gerganov
Date:   2023-08-16 21:52:06 +03:00
commit 8be49fdf9e
parent 250cf83847

diff --git a/convert-new.py b/convert-new.py
--- a/convert-new.py
+++ b/convert-new.py

@@ -117,10 +117,12 @@ class Params:
     n_vocab:    int
     n_embd:     int
     n_mult:     int
-    n_head:     int
     n_layer:    int
     n_ctx:      int
-    n_kv_head:  Optional[int]  # This parameter is only used for Llama 2
+    n_ff:       int
+    n_head:     int
+    n_head_kv:  int
+    f_norm_eps: float
 
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
@@ -140,15 +142,22 @@ class Params:
                 "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         n_head = n_embd // 128 # guessed
+        n_mult = 255           # guessed
+
+        # TODO: verify this
+        n_ff = int(2 * (4 * n_embd) / 3)
+        n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
 
         return Params(
             n_vocab    = n_vocab,
             n_embd     = n_embd,
             n_mult     = 256,
-            n_head     = n_head,
             n_layer    = n_layer,
             n_ctx      = -1,
-            n_kv_head  = None,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head,
+            f_norm_eps = 1e-5,
         )
 
     @staticmethod
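
Note: the guessed feed-forward size above follows the SwiGLU sizing rule (two thirds of 4*n_embd, rounded up to a multiple of n_mult). Traced by hand for a hypothetical n_embd = 4096 (7B-sized) model, as a standalone sketch rather than part of the script:

    n_embd = 4096
    n_mult = 255                                     # the guess used above
    n_ff = int(2 * (4 * n_embd) / 3)                 # int(32768 / 3) = 10922
    n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)  # round up: 43 * 255 = 10965
    # With n_mult = 256 the same rounding gives 11008, LLaMA-7B's actual n_ff.
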
@@ -157,10 +166,11 @@ class Params:
         n_vocab = config["vocab_size"];
         n_embd  = config["hidden_size"];
-        n_head  = config["num_attention_heads"];
         n_layer = config["num_hidden_layers"];
         n_ff    = config["intermediate_size"];
-        n_kv_head  = config.get("num_key_value_heads")
+        n_head     = config["num_attention_heads"];
+        n_head_kv  = config["num_key_value_heads"];
+        f_norm_eps = config["rms_norm_eps"];
 
         n_mult = find_n_mult(n_ff, n_embd);
@@ -176,10 +186,12 @@ class Params:
             n_vocab    = n_vocab,
             n_embd     = n_embd,
             n_mult     = n_mult,
-            n_head     = n_head,
             n_layer    = n_layer,
             n_ctx      = n_ctx,
-            n_kv_head  = n_kv_head,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head_kv,
+            f_norm_eps = f_norm_eps,
         )
 
     # LLaMA v2 70B params.json
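
For orientation, the keys read above are the standard Hugging Face config.json fields. A grouped-query model such as LLaMA-2-70B publishes, among others, the following (values quoted from the published config for illustration only; check the model's own config.json):

    config = {
        "vocab_size":          32000,   # -> n_vocab
        "hidden_size":         8192,    # -> n_embd
        "num_hidden_layers":   80,      # -> n_layer
        "intermediate_size":   28672,   # -> n_ff
        "num_attention_heads": 64,      # -> n_head
        "num_key_value_heads": 8,       # -> n_head_kv (GQA: 8 KV heads shared by 64 query heads)
        "rms_norm_eps":        1e-05,   # -> f_norm_eps
    }
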
@@ -190,21 +202,30 @@ class Params:
         n_vocab = config["vocab_size"];
         n_embd  = config["dim"];
-        n_head  = config["n_heads"];
         n_layer = config["n_layers"];
         n_mult  = config["multiple_of"];
+        n_ctx   = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
+        n_ff    = -1;
+        n_head     = config["n_heads"];
+        n_head_kv  = config["n_kv_head"] if "n_kv_head" in config else n_head;
+        f_norm_eps = config["norm_eps"];
 
         if n_vocab == -1:
             n_vocab = model["tok_embeddings.weight"].shape[0]
 
+        if n_ff == -1:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
         return Params(
             n_vocab    = n_vocab,
             n_embd     = n_embd,
             n_mult     = n_mult,
-            n_head     = n_head,
             n_layer    = n_layer,
-            n_ctx      = -1,
-            n_kv_head  = None,
+            n_ctx      = n_ctx,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head_kv,
+            f_norm_eps = f_norm_eps,
         )
 
     @staticmethod
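
The n_ctx hack above works because the two model generations ship different norm_eps values in params.json: original LLaMA uses 1e-06 (trained at 2048-token context), LLaMA 2 uses 1e-05 (trained at 4096). The branch on its own, as a self-contained sketch:

    for norm_eps in (1e-06, 1e-05):                  # LLaMA v1, LLaMA v2
        n_ctx = 2048 if norm_eps == 1e-06 else 4096
        print(norm_eps, n_ctx)                       # 1e-06 2048, then 1e-05 4096
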
@@ -310,9 +331,9 @@ class SentencePieceVocab:
 Vocab = Union[BpeVocab, SentencePieceVocab]
 
-def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
-    if n_kv_head is not None and n_head != n_kv_head:
-        n_head //= n_kv_head
+def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
+    if n_head_kv is not None and n_head != n_head_kv:
+        n_head //= n_head_kv
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
             .swapaxes(1, 2)
             .reshape(weights.shape))
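
The reshape/swapaxes pair converts the row interleaving that Hugging Face checkpoints use for rotary embeddings in Q and K back to the original layout; when n_head != n_head_kv, the integer division first makes the permutation act once per key/value group. A toy trace using the permute defined above (hypothetical 8x4 weight whose row i is filled with the value i):

    import numpy as np

    w = np.arange(8)[:, None] * np.ones((1, 4))   # 8x4 matrix, row i holds the value i
    out = permute(w, n_head=2, n_head_kv=2)       # equal head counts, so no group division
    print(out[:, 0])                              # [0. 2. 1. 3. 4. 6. 5. 7.]

In this tiny case that amounts to swapping the two middle rows of each 4-row head; in general it transposes each head's (2, head_dim/2) block of rows.
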
@@ -324,7 +345,7 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
+    def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
     @abstractmethod
     def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
     @abstractmethod
@@ -362,8 +383,8 @@ class UnquantizedTensor(Tensor):
         r = self.ndarray.shape[0] // 3
         return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
 
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
+    def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
 
 
 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -386,18 +407,18 @@ GGMLCompatibleTensor = Union[UnquantizedTensor]
 
 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
         self.base = base
         self.n_head = n_head
         self.data_type = self.base.data_type
 
     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
 
     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
 
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
         raise Exception("shouldn't permute twice")
@@ -493,10 +514,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)
 
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head, n_kv_head)
-    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
+        return lazy_tensor.load().permute(n_head, n_head_kv)
+    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
 
 def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
     def load() -> Tensor:
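
permute_lazy only wraps the loader; the permutation runs when the tensor is finally materialized, which keeps peak memory flat while the whole model graph is wired up. The deferral pattern in miniature (LazyDemo is a hypothetical stand-in for the script's LazyTensor, reusing the permute defined earlier):

    from typing import Callable
    import numpy as np

    class LazyDemo:
        # hypothetical stand-in: holds a thunk, not data
        def __init__(self, load: Callable[[], np.ndarray], description: str) -> None:
            self.load = load              # nothing is loaded or permuted yet
            self.description = description

    def permute_demo_lazy(inner: LazyDemo, n_head: int, n_head_kv: int) -> LazyDemo:
        # permute() runs only when .load() is eventually called
        return LazyDemo(lambda: permute(inner.load(), n_head, n_head_kv),
                        f'permute({n_head}, {n_head_kv}) ' + inner.description)
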
@@ -521,7 +542,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
             out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
@@ -735,6 +756,12 @@ class OutputFile:
         self.gguf.add_architecture        (llm_arch)
         self.gguf.add_context_length      (llm_arch, params.n_ctx)
         self.gguf.add_embedding_length    (llm_arch, params.n_embd)
+        self.gguf.add_block_count         (llm_arch, params.n_layer)
+        self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
+        self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
+        self.gguf.add_head_count          (llm_arch, params.n_head)
+        self.gguf.add_head_count_kv       (llm_arch, params.n_head_kv)
+        self.gguf.add_layer_norm_rms_eps  (llm_arch, params.f_norm_eps)
 
     def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
         sname = name.encode('utf-8')
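
Each add_* call above stores one key-value pair in the GGUF header. Assuming the key names from the GGUF specification this branch is built against (the authoritative spelling lives in gguf.py), the calls correspond approximately to:

    add_context_length       -> llama.context_length
    add_embedding_length     -> llama.embedding_length
    add_block_count          -> llama.block_count
    add_feed_forward_length  -> llama.feed_forward_length
    add_rope_dimension_count -> llama.rope.dimension_count   (n_embd // n_head, the per-head dim)
    add_head_count           -> llama.attention.head_count
    add_head_count_kv        -> llama.attention.head_count_kv
    add_layer_norm_rms_eps   -> llama.attention.layer_norm_rms_epsilon
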
@@ -744,15 +771,22 @@ class OutputFile:
         self.fout.seek((self.fout.tell() + 31) & -32)
 
     def write_vocab(self, vocab: Vocab) -> None:
+        tokens = []
+        scores = []
         for text, score in vocab.all_tokens():
-            self.fout.write(struct.pack("i", len(text)))
-            self.fout.write(text)
-            self.fout.write(struct.pack("f", score))
+            tokens.append(text)
+            scores.append(score)
+
+        self.gguf.add_tokenizer_model("llama")
+        self.gguf.add_token_list(tokens)
+        self.gguf.add_token_scores(scores)
+
+        #self.gguf.add_token_types(toktypes) # TODO: add this
+        # TODO: added / special tokens
 
     @staticmethod
-    def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
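
The removed struct.pack lines show the legacy vocab encoding being retired: each token was written inline as an int32 length, the raw UTF-8 bytes, then a float32 score. That old record layout as a standalone sketch, for comparison:

    import struct

    def pack_legacy_token(text: bytes, score: float) -> bytes:
        # old GGML layout: int32 length + raw bytes + float32 score
        return struct.pack("i", len(text)) + text + struct.pack("f", score)

GGUF instead collects all tokens and scores into lists and writes them once as header arrays (the add_token_list / add_token_scores calls above).
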
@@ -941,12 +975,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
     args = parser.parse_args(args_in)
 
-    vocab: Vocab
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
+
     model_plus = load_some_model(args.model)
     params = Params.load(model_plus)
 
     if params.n_ctx == -1:
         if args.ctx is None:
@@ -958,6 +992,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
 
     print(f"params = {params}")
 
+    vocab: Vocab
     if args.vocab_only:
         vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
         assert args.outfile, "need --outfile if using --vocab-only"
@@ -968,6 +1003,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         if args.dump:
             do_dump_model(model_plus)
             return
+
         if model_plus.vocab is not None and args.vocab_dir is None:
             vocab = model_plus.vocab
         else: