convert-new.py : add gguf key-value pairs
This commit is contained in:
parent
250cf83847
commit
8be49fdf9e
1 changed files with 99 additions and 63 deletions
108
convert-new.py
108
convert-new.py
|
@ -117,10 +117,12 @@ class Params:
|
||||||
n_vocab: int
|
n_vocab: int
|
||||||
n_embd: int
|
n_embd: int
|
||||||
n_mult: int
|
n_mult: int
|
||||||
n_head: int
|
|
||||||
n_layer: int
|
n_layer: int
|
||||||
n_ctx: int
|
n_ctx: int
|
||||||
n_kv_head: Optional[int] # This parameter is only used for Llama 2
|
n_ff: int
|
||||||
|
n_head: int
|
||||||
|
n_head_kv: int
|
||||||
|
f_norm_eps: float
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def guessed(model: 'LazyModel') -> 'Params':
|
def guessed(model: 'LazyModel') -> 'Params':
|
||||||
|
@ -139,16 +141,23 @@ class Params:
|
||||||
raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
|
raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
|
||||||
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
|
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
|
||||||
|
|
||||||
n_head=n_embd // 128 # guessed
|
n_head = n_embd // 128 # guessed
|
||||||
|
n_mult = 255 # guessed
|
||||||
|
|
||||||
|
# TODO: verify this
|
||||||
|
n_ff = int(2 * (4 * n_embd) / 3)
|
||||||
|
n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
|
||||||
|
|
||||||
return Params(
|
return Params(
|
||||||
n_vocab = n_vocab,
|
n_vocab = n_vocab,
|
||||||
n_embd = n_embd,
|
n_embd = n_embd,
|
||||||
n_mult = 256,
|
n_mult = 256,
|
||||||
n_head = n_head,
|
|
||||||
n_layer = n_layer,
|
n_layer = n_layer,
|
||||||
n_ctx = -1,
|
n_ctx = -1,
|
||||||
n_kv_head = None,
|
n_ff = n_ff,
|
||||||
|
n_head = n_head,
|
||||||
|
n_head_kv = n_head,
|
||||||
|
f_norm_eps = 1e-5,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -157,10 +166,11 @@ class Params:
|
||||||
|
|
||||||
n_vocab = config["vocab_size"];
|
n_vocab = config["vocab_size"];
|
||||||
n_embd = config["hidden_size"];
|
n_embd = config["hidden_size"];
|
||||||
n_head = config["num_attention_heads"];
|
|
||||||
n_layer = config["num_hidden_layers"];
|
n_layer = config["num_hidden_layers"];
|
||||||
n_ff = config["intermediate_size"];
|
n_ff = config["intermediate_size"];
|
||||||
n_kv_head = config.get("num_key_value_heads")
|
n_head = config["num_attention_heads"];
|
||||||
|
n_head_kv = config["num_key_value_heads"];
|
||||||
|
f_norm_eps = config["rms_norm_eps"];
|
||||||
|
|
||||||
n_mult = find_n_mult(n_ff, n_embd);
|
n_mult = find_n_mult(n_ff, n_embd);
|
||||||
|
|
||||||
|
@ -176,10 +186,12 @@ class Params:
|
||||||
n_vocab = n_vocab,
|
n_vocab = n_vocab,
|
||||||
n_embd = n_embd,
|
n_embd = n_embd,
|
||||||
n_mult = n_mult,
|
n_mult = n_mult,
|
||||||
n_head = n_head,
|
|
||||||
n_layer = n_layer,
|
n_layer = n_layer,
|
||||||
n_ctx = n_ctx,
|
n_ctx = n_ctx,
|
||||||
n_kv_head = n_kv_head,
|
n_ff = n_ff,
|
||||||
|
n_head = n_head,
|
||||||
|
n_head_kv = n_head_kv,
|
||||||
|
f_norm_eps = f_norm_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# LLaMA v2 70B params.json
|
# LLaMA v2 70B params.json
|
||||||
|
@ -190,21 +202,30 @@ class Params:
|
||||||
|
|
||||||
n_vocab = config["vocab_size"];
|
n_vocab = config["vocab_size"];
|
||||||
n_embd = config["dim"];
|
n_embd = config["dim"];
|
||||||
n_head = config["n_heads"];
|
|
||||||
n_layer = config["n_layers"];
|
n_layer = config["n_layers"];
|
||||||
n_mult = config["multiple_of"];
|
n_mult = config["multiple_of"];
|
||||||
|
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
|
||||||
|
n_ff = -1;
|
||||||
|
n_head = config["n_heads"];
|
||||||
|
n_head_kv = config["n_kv_head"] if "n_kv_head" in config else n_head;
|
||||||
|
f_norm_eps = config["norm_eps"];
|
||||||
|
|
||||||
if n_vocab == -1:
|
if n_vocab == -1:
|
||||||
n_vocab = model["tok_embeddings.weight"].shape[0]
|
n_vocab = model["tok_embeddings.weight"].shape[0]
|
||||||
|
|
||||||
|
if n_ff == -1:
|
||||||
|
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
|
||||||
|
|
||||||
return Params(
|
return Params(
|
||||||
n_vocab = n_vocab,
|
n_vocab = n_vocab,
|
||||||
n_embd = n_embd,
|
n_embd = n_embd,
|
||||||
n_mult = n_mult,
|
n_mult = n_mult,
|
||||||
n_head = n_head,
|
|
||||||
n_layer = n_layer,
|
n_layer = n_layer,
|
||||||
n_ctx = -1,
|
n_ctx = n_ctx,
|
||||||
n_kv_head = None,
|
n_ff = n_ff,
|
||||||
|
n_head = n_head,
|
||||||
|
n_head_kv = n_head_kv,
|
||||||
|
f_norm_eps = f_norm_eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -310,9 +331,9 @@ class SentencePieceVocab:
|
||||||
Vocab = Union[BpeVocab, SentencePieceVocab]
|
Vocab = Union[BpeVocab, SentencePieceVocab]
|
||||||
|
|
||||||
|
|
||||||
def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
|
def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
|
||||||
if n_kv_head is not None and n_head != n_kv_head:
|
if n_head_kv is not None and n_head != n_head_kv:
|
||||||
n_head //= n_kv_head
|
n_head //= n_head_kv
|
||||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||||
.swapaxes(1, 2)
|
.swapaxes(1, 2)
|
||||||
.reshape(weights.shape))
|
.reshape(weights.shape))
|
||||||
|
@ -324,7 +345,7 @@ class Tensor(metaclass=ABCMeta):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def astype(self, data_type: DataType) -> 'Tensor': ...
|
def astype(self, data_type: DataType) -> 'Tensor': ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
|
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
|
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -362,8 +383,8 @@ class UnquantizedTensor(Tensor):
|
||||||
r = self.ndarray.shape[0] // 3
|
r = self.ndarray.shape[0] // 3
|
||||||
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
|
||||||
|
|
||||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
|
def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
|
||||||
return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
|
return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
|
||||||
|
|
||||||
|
|
||||||
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
|
||||||
|
@ -386,18 +407,18 @@ GGMLCompatibleTensor = Union[UnquantizedTensor]
|
||||||
|
|
||||||
|
|
||||||
class DeferredPermutedTensor(Tensor):
|
class DeferredPermutedTensor(Tensor):
|
||||||
def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
|
def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
|
||||||
self.base = base
|
self.base = base
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.data_type = self.base.data_type
|
self.data_type = self.base.data_type
|
||||||
|
|
||||||
def astype(self, data_type: DataType) -> Tensor:
|
def astype(self, data_type: DataType) -> Tensor:
|
||||||
return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
|
return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
|
||||||
|
|
||||||
def to_ggml(self) -> GGMLCompatibleTensor:
|
def to_ggml(self) -> GGMLCompatibleTensor:
|
||||||
return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
|
return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
|
||||||
|
|
||||||
def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
|
def permute(self, n_head: int, n_head_kv: int) -> Tensor:
|
||||||
raise Exception("shouldn't permute twice")
|
raise Exception("shouldn't permute twice")
|
||||||
|
|
||||||
|
|
||||||
|
@ -493,10 +514,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
|
||||||
return ModelPlus(model, paths, format, vocab)
|
return ModelPlus(model, paths, format, vocab)
|
||||||
|
|
||||||
|
|
||||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
|
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
|
||||||
def load() -> Tensor:
|
def load() -> Tensor:
|
||||||
return lazy_tensor.load().permute(n_head, n_kv_head)
|
return lazy_tensor.load().permute(n_head, n_head_kv)
|
||||||
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
|
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
|
||||||
|
|
||||||
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
|
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
|
||||||
def load() -> Tensor:
|
def load() -> Tensor:
|
||||||
|
@ -521,7 +542,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
|
||||||
for i in itertools.count():
|
for i in itertools.count():
|
||||||
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
|
||||||
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
|
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
|
||||||
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
|
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
|
||||||
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
|
||||||
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
|
||||||
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
|
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
|
||||||
|
@ -732,9 +753,15 @@ class OutputFile:
|
||||||
def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
|
def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
|
||||||
llm_arch = "llama"
|
llm_arch = "llama"
|
||||||
|
|
||||||
self.gguf.add_architecture(llm_arch)
|
self.gguf.add_architecture (llm_arch)
|
||||||
self.gguf.add_context_length(llm_arch, params.n_ctx)
|
self.gguf.add_context_length (llm_arch, params.n_ctx)
|
||||||
self.gguf.add_embedding_length(llm_arch, params.n_embd)
|
self.gguf.add_embedding_length (llm_arch, params.n_embd)
|
||||||
|
self.gguf.add_block_count (llm_arch, params.n_layer)
|
||||||
|
self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
|
||||||
|
self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
|
||||||
|
self.gguf.add_head_count (llm_arch, params.n_head)
|
||||||
|
self.gguf.add_head_count_kv (llm_arch, params.n_head_kv)
|
||||||
|
self.gguf.add_layer_norm_rms_eps (llm_arch, params.f_norm_eps)
|
||||||
|
|
||||||
def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
|
def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
|
||||||
sname = name.encode('utf-8')
|
sname = name.encode('utf-8')
|
||||||
|
@ -744,15 +771,22 @@ class OutputFile:
|
||||||
self.fout.seek((self.fout.tell() + 31) & -32)
|
self.fout.seek((self.fout.tell() + 31) & -32)
|
||||||
|
|
||||||
def write_vocab(self, vocab: Vocab) -> None:
|
def write_vocab(self, vocab: Vocab) -> None:
|
||||||
|
tokens = []
|
||||||
|
scores = []
|
||||||
for text, score in vocab.all_tokens():
|
for text, score in vocab.all_tokens():
|
||||||
self.fout.write(struct.pack("i", len(text)))
|
tokens.append(text)
|
||||||
self.fout.write(text)
|
scores.append(score)
|
||||||
self.fout.write(struct.pack("f", score))
|
|
||||||
|
self.gguf.add_tokenizer_model("llama")
|
||||||
|
self.gguf.add_token_list(tokens)
|
||||||
|
self.gguf.add_token_scores(scores)
|
||||||
|
#self.gguf.add_token_types(toktypes) # TODO: add this
|
||||||
|
|
||||||
|
# TODO: added / special tokens
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
|
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
|
||||||
of = OutputFile(fname_out)
|
of = OutputFile(fname_out)
|
||||||
params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
|
|
||||||
of = OutputFile(fname_out)
|
of = OutputFile(fname_out)
|
||||||
of.write_file_header(params, file_type=GGMLFileType.AllF32)
|
of.write_file_header(params, file_type=GGMLFileType.AllF32)
|
||||||
of.write_vocab(vocab)
|
of.write_vocab(vocab)
|
||||||
|
@ -941,12 +975,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
||||||
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
||||||
args = parser.parse_args(args_in)
|
args = parser.parse_args(args_in)
|
||||||
|
|
||||||
vocab: Vocab
|
|
||||||
if args.dump_single:
|
if args.dump_single:
|
||||||
model_plus = lazy_load_file(args.model)
|
model_plus = lazy_load_file(args.model)
|
||||||
do_dump_model(model_plus)
|
do_dump_model(model_plus)
|
||||||
|
|
||||||
model_plus = load_some_model(args.model)
|
model_plus = load_some_model(args.model)
|
||||||
|
|
||||||
params = Params.load(model_plus)
|
params = Params.load(model_plus)
|
||||||
if params.n_ctx == -1:
|
if params.n_ctx == -1:
|
||||||
if args.ctx is None:
|
if args.ctx is None:
|
||||||
|
@ -958,6 +992,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
||||||
|
|
||||||
print(f"params = {params}")
|
print(f"params = {params}")
|
||||||
|
|
||||||
|
vocab: Vocab
|
||||||
if args.vocab_only:
|
if args.vocab_only:
|
||||||
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
|
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
|
||||||
assert args.outfile, "need --outfile if using --vocab-only"
|
assert args.outfile, "need --outfile if using --vocab-only"
|
||||||
|
@ -968,6 +1003,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
||||||
if args.dump:
|
if args.dump:
|
||||||
do_dump_model(model_plus)
|
do_dump_model(model_plus)
|
||||||
return
|
return
|
||||||
|
|
||||||
if model_plus.vocab is not None and args.vocab_dir is None:
|
if model_plus.vocab is not None and args.vocab_dir is None:
|
||||||
vocab = model_plus.vocab
|
vocab = model_plus.vocab
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue