fix linter complaints
Parent: 66ccd62102
Commit: e9abcc9c7c
2 changed files with 211 additions and 224 deletions
@@ -41,12 +41,12 @@ model_instance.set_gguf_parameters()
 print("Set model tokenizer")
 model_instance.set_vocab()
 
-if not args.vocab_only:
-    print(f"Exporting model to '{fname_out}'")
-    model_instance.write()
-else:
+if args.vocab_only:
     print(f"Exporting model vocab to '{fname_out}'")
     model_instance.write_vocab()
+else:
+    print(f"Exporting model to '{fname_out}'")
+    model_instance.write()
 
 print(f"Model successfully exported to '{fname_out}'")
 print("")
model.py (427 changed lines)
@@ -5,7 +5,7 @@ import re
 import sys
 from enum import IntEnum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Iterator, TypeAlias, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
 
 import numpy as np
 import torch
@@ -39,6 +39,128 @@ class Model:
         self.model_arch = self._get_model_architecture()
         self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch])
 
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        for part_name in self.part_names:
+            print(f"gguf: loading model part '{part_name}'")
+            ctx: ContextManager[Any]
+            if self.is_safetensors:
+                from safetensors import safe_open
+                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+            else:
+                ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))
+
+            with ctx as model_part:
+                for name in model_part.keys():
+                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
+                    yield name, data
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams.get(
+            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
+        ))
+        if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
+            self.gguf_writer.add_context_length(n_ctx)
+        if (n_embd := self.hparams.get("hidden_size")) is not None:
+            self.gguf_writer.add_embedding_length(n_embd)
+        if (n_ff := self.hparams.get("intermediate_size")) is not None:
+            self.gguf_writer.add_feed_forward_length(n_ff)
+        if (n_head := self.hparams.get("num_attention_head")) is not None:
+            self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+    def write(self):
+        self.write_tensors()
+        self.gguf_writer.write_header_to_file()
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.close()
+
+    def write_vocab(self):
+        self.gguf_writer.write_header_to_file()
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.close()
+
+    @staticmethod
+    def count_model_parts(dir_model: Path, prefix: str) -> int:
+        num_parts = 0
+        for filename in os.listdir(dir_model):
+            if filename.endswith(prefix):
+                num_parts += 1
+
+        return num_parts
+
+    @staticmethod
+    def load_hparams(dir_model):
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def from_model_architecture(model_architecture):
+        if model_architecture == "StableLMEpochForCausalLM":
+            return StableLMModel
+        if model_architecture == "GPTNeoXForCausalLM":
+            return GPTNeoXModel
+        if model_architecture == "BloomForCausalLM":
+            return BloomModel
+        if model_architecture == "MPTForCausalLM":
+            return MPTModel
+        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
+            return BaichuanModel
+        if model_architecture == "FalconForCausalLM":
+            return FalconModel
+        if model_architecture == "GPTBigCodeForCausalLM":
+            return StarCoderModel
+        if model_architecture == "GPTRefactForCausalLM":
+            return RefactModel
+        if model_architecture == "PersimmonForCausalLM":
+            return PersimmonModel
+        return Model
+
     def _is_model_safetensors(self) -> bool:
         return Model.count_model_parts(self.dir_model, ".safetensors") > 0
@@ -47,10 +169,10 @@ class Model:
             if self.num_parts == 1:  # there's only one .safetensors file
                 return ("model.safetensors",)
             return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1))
-        else:
+
         if self.num_parts == 1:  # there's only one .bin file
             return ("pytorch_model.bin",)
         return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
 
     def _get_model_architecture(self) -> gguf.MODEL_ARCH:
         arch = self.hparams["architectures"][0]
@@ -84,7 +206,7 @@ class Model:
         vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
         assert max(tokenizer.vocab.values()) < vocab_size
 
-        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
         added_vocab = tokenizer.get_added_vocab()
 
         for i in range(vocab_size):
@@ -162,135 +284,13 @@ class Model:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-    def set_vocab(self):
-        self._set_vocab_gpt2()
-
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for part_name in self.part_names:
-            print("gguf: loading model part '" + part_name + "'")
-            ctx: ContextManager[Any]
-            if self.is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
-            else:
-                ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))
-
-            with ctx as model_part:
-                for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    yield name, data
-
-    def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name)
-        self.gguf_writer.add_block_count(self.hparams.get(
-            "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))))
-        if "max_position_embeddings" in self.hparams:
-            self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
-        if "hidden_size" in self.hparams:
-            self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
-        if "intermediate_size" in self.hparams:
-            self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-        if "num_attention_head" in self.hparams:
-            self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-        self.gguf_writer.add_parallel_residual(
-            self.hparams["use_parallel_residual"] if "use_parallel_residual" in self.hparams else True)
-
-    def write_tensors(self):
-        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
-        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
-        for name, data_torch in self.get_tensors():
-            # we don't need these
-            if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
-                continue
-
-            old_dtype = data_torch.dtype
-
-            # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
-                data_torch = data_torch.to(torch.float32)
-
-            data = data_torch.squeeze().numpy()
-
-            # map tensor names
-            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
-            if new_name is None:
-                print("Can not map tensor '" + name + "'")
-                sys.exit()
-
-            n_dims = len(data.shape)
-            data_dtype = data.dtype
-
-            # if f32 desired, convert any float16 to float32
-            if self.ftype == 0 and data_dtype == np.float16:
-                data = data.astype(np.float32)
-
-            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
-            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
-                data = data.astype(np.float32)
-
-            # if f16 desired, convert any float32 2-dim weight tensors to float16
-            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
-                data = data.astype(np.float16)
-
-            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
-
-            self.gguf_writer.add_tensor(new_name, data)
-
-    def write(self):
-        self.write_tensors()
-        self.gguf_writer.write_header_to_file()
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
-        self.gguf_writer.close()
-
-    def write_vocab(self):
-        self.gguf_writer.write_header_to_file()
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.close()
-
-    @staticmethod
-    def count_model_parts(dir_model: Path, prefix: str) -> int:
-        num_parts = 0
-        for filename in os.listdir(dir_model):
-            if filename.endswith(prefix):
-                num_parts += 1
-
-        return num_parts
-
-    @staticmethod
-    def load_hparams(dir_model):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        return hparams
-
-    @staticmethod
-    def from_model_architecture(model_architecture):
-        if model_architecture == "StableLMEpochForCausalLM":
-            return StableLMModel
-        if model_architecture == "GPTNeoXForCausalLM":
-            return GPTNeoXModel
-        if model_architecture == "BloomForCausalLM":
-            return BloomModel
-        if model_architecture == "MPTForCausalLM":
-            return MPTModel
-        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
-            return BaichuanModel
-        if model_architecture == "FalconForCausalLM":
-            return FalconModel
-        if model_architecture == "GPTBigCodeForCausalLM":
-            return StarCoderModel
-        if model_architecture == "GPTRefactForCausalLM":
-            return RefactModel
-        if model_architecture == "PersimmonForCausalLM":
-            return PersimmonModel
-        return Model
-
 
 class StableLMModel(Model):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_dimension_count(
-            int(self.hparams["rope_pct"]*(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])))
+            int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
+        )
         self.gguf_writer.add_layer_norm_eps(1e-5)
@@ -304,10 +304,10 @@ class GPTNeoXModel(Model):
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
         self.gguf_writer.add_rope_dimension_count(
-            int(self.hparams["rotary_pct"]*(self.hparams["hidden_size"]//self.hparams["num_attention_heads"])))
+            int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
+        )
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-        self.gguf_writer.add_parallel_residual(
-            self.hparams["use_parallel_residual"] if "use_parallel_residual" in self.hparams else True)
+        self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
@@ -342,7 +342,7 @@ class BloomModel(Model):
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
+            if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             data = data_torch.squeeze().numpy()
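The same rewrite recurs in every write_tensors variant below: the chained dtype comparisons collapse into one membership test over a tuple. A minimal standalone sketch of the equivalence (the toy tensor is an assumption, not from the diff):

import torch

t = torch.zeros(2, dtype=torch.bfloat16)

# before: two chained != comparisons
needs_cast_old = t.dtype != torch.float16 and t.dtype != torch.float32

# after: a single containment check, same result
needs_cast_new = t.dtype not in (torch.float16, torch.float32)

assert needs_cast_old == needs_cast_new  # True for bfloat16
if needs_cast_new:
    t = t.to(torch.float32)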
@@ -353,26 +353,30 @@ class BloomModel(Model):
                 # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
                 qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
                 data = np.concatenate(
-                    (qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
-                     qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
-                     qkv_weights[:, 2, :, :].reshape((-1, n_embed))),
-                    axis=0
+                    (
+                        qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                        qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                    ),
+                    axis=0,
                 )
                 print("re-format attention.linear_qkv.weight")
             elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
                 qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
                 data = np.concatenate(
-                    (qkv_bias[:, 0, :].reshape((n_embed,)),
-                     qkv_bias[:, 1, :].reshape((n_embed,)),
-                     qkv_bias[:, 2, :].reshape((n_embed,))),
-                    axis=0
+                    (
+                        qkv_bias[:, 0, :].reshape((n_embed,)),
+                        qkv_bias[:, 1, :].reshape((n_embed,)),
+                        qkv_bias[:, 2, :].reshape((n_embed,)),
+                    ),
+                    axis=0,
                 )
                 print("re-format attention.linear_qkv.bias")
 
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
 
             n_dims = len(data.shape)
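The reshape/concatenate above only changes layout, not values: BLOOM stores Q, K and V interleaved per head, and viewing the fused weight as (n_head, 3, n_embed // n_head, n_embed) separates them. A self-contained numpy sketch with toy sizes (all names and sizes here are illustrative, not from the diff):

import numpy as np

n_head, n_embed = 2, 4           # toy sizes, assumed for illustration
head_dim = n_embed // n_head

q = np.arange(n_embed * n_embed).reshape(n_embed, n_embed)
k, v = q + 100, q + 200

# fuse the way BLOOM does: for each head, its Q rows, then K rows, then V rows
fused = np.concatenate([
    np.concatenate((q[h * head_dim:(h + 1) * head_dim],
                    k[h * head_dim:(h + 1) * head_dim],
                    v[h * head_dim:(h + 1) * head_dim]))
    for h in range(n_head)
])                                # shape (3 * n_embed, n_embed)

# the conversion's reshape + concatenate undoes the per-head interleaving
qkv = fused.reshape((n_head, 3, head_dim, n_embed))
data = np.concatenate(
    (
        qkv[:, 0, :, :].reshape((-1, n_embed)),
        qkv[:, 1, :, :].reshape((-1, n_embed)),
        qkv[:, 2, :, :].reshape((-1, n_embed)),
    ),
    axis=0,
)
assert (data == np.concatenate((q, k, v))).all()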
@@ -390,13 +394,13 @@ class BloomModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
 
             if not has_lm_head and name == "word_embeddings.weight":
                 self.gguf_writer.add_tensor("output.weight", data)
-                print(name, "=>", "output.weight" + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))  # noqa
+                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
 
 class MPTModel(Model):
@@ -410,7 +414,7 @@ class MPTModel(Model):
         self.gguf_writer.add_head_count(self.hparams["n_heads"])
         if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
             self.gguf_writer.add_head_count_kv(kv_n_heads)
-        self.gguf_writer.add_layer_norm_eps(1e-05)
+        self.gguf_writer.add_layer_norm_eps(1e-5)
         if self.hparams["attn_config"]["clip_qkv"] is not None:
             self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
         self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
@@ -420,13 +424,13 @@ class MPTModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
         for name, data_torch in self.get_tensors():
             # we don't need these
-            if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                 continue
 
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
+            if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             data = data_torch.squeeze().numpy()
@@ -434,7 +438,7 @@ class MPTModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
 
             n_dims = len(data.shape)
@@ -452,7 +456,7 @@ class MPTModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
@@ -469,16 +473,8 @@ class BaichuanModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
-        if "num_key_value_heads" in self.hparams:
-            head_count_kv = self.hparams["num_key_value_heads"]
-        else:
-            head_count_kv = head_count
-
-        if "_name_or_path" in self.hparams:
-            hf_repo = self.hparams["_name_or_path"]
-        else:
-            hf_repo = ""
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
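A quick sketch of the dict.get rewrite above, which folds each four-line if/else into a single lookup with a default (the hparams dict here is a hypothetical stand-in):

hparams = {"num_attention_heads": 32}          # illustrative config
head_count = hparams["num_attention_heads"]

# before
if "num_key_value_heads" in hparams:
    head_count_kv = hparams["num_key_value_heads"]
else:
    head_count_kv = head_count

# after: identical behavior in one line, since .get() falls back to the
# default exactly when the key is absent
assert hparams.get("num_key_value_heads", head_count) == head_count_kv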
@@ -503,26 +499,9 @@ class BaichuanModel(Model):
         self.gguf_writer.add_head_count_kv(head_count_kv)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
 
-        if "rope_scaling" in self.hparams and self.hparams["rope_scaling"] is not None and "factor" in self.hparams["rope_scaling"]:
-            if "type" in self.hparams["rope_scaling"]:
-                if self.hparams["rope_scaling"]["type"] == "linear":
-                    self.gguf_writer.add_rope_scale_linear(self.hparams["rope_scaling"]["factor"])
-
-    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
-        if n_kv_head is not None and n_head != n_kv_head:
-            n_head //= n_kv_head
-
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                .swapaxes(1, 2)
-                .reshape(weights.shape))
-
-    def _reverse_hf_permute_part(self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None) -> Tensor:
-        r = weights.shape[0] // 3
-        return (self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv))
-
-    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
-        r = weights.shape[0] // 3
-        return weights[r * n_part:r * n_part + r, ...]
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scale_linear(self.hparams["rope_scaling"]["factor"])
 
     def write_tensors(self):
         # Collect tensors from generator object
@@ -530,21 +509,17 @@ class BaichuanModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
-        if "num_key_value_heads" in self.hparams:
-            head_count_kv = self.hparams["num_key_value_heads"]
-        else:
-            head_count_kv = head_count
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
 
         for i in range(block_count):
-            if f"model.layers.{i}.self_attn.W_pack.weight" in model_kv:
+            if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None:
                 print(f"Unpacking and permuting layer {i}")
-                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = self._reverse_hf_permute_part(
-                    model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 0, head_count, head_count)
-                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = self._reverse_hf_permute_part(
-                    model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 1, head_count, head_count_kv)
-                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = self._reverse_hf_part(
-                    model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \
+                    self._reverse_hf_permute_part(w, 0, head_count, head_count)
+                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \
+                    self._reverse_hf_permute_part(w, 1, head_count, head_count_kv)
+                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \
+                    self._reverse_hf_part(w, 2)
                 del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]
 
         for name, data_torch in model_kv.items():
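The walrus pattern introduced here (and again in RefactModel below) avoids the double lookup of a membership test followed by indexing. A minimal sketch with a hypothetical dict and key; note it tests `is not None` rather than membership, which is equivalent as long as no stored value is None, as is the case for these tensor dicts:

tensors = {"model.layers.0.self_attn.W_pack.weight": 1.0}
key = "model.layers.0.self_attn.W_pack.weight"

# before: membership test, then a second lookup to read the value
if key in tensors:
    w = tensors[key]

# after: one .get(), bound and tested in the same expression
if (w := tensors.get(key)) is not None:
    print(f"got {w}")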
@@ -555,7 +530,7 @@ class BaichuanModel(Model):
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
+            if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             data = data_torch.squeeze().numpy()
@@ -563,7 +538,7 @@ class BaichuanModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
 
             n_dims = len(data.shape)
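The {name!r} conversion used in these new print calls reproduces the old manual quoting exactly, since !r applies repr(), which wraps a plain string in single quotes:

name = "transformer.h.0.some_tensor"  # hypothetical tensor name
assert f"Can not map tensor {name!r}" == "Can not map tensor '" + name + "'"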
@@ -581,9 +556,29 @@ class BaichuanModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(name + " -> " + new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)
 
+    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
+        if n_kv_head is not None and n_head != n_kv_head:
+            n_head //= n_kv_head
+
+        return (
+            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(weights.shape)
+        )
+
+    def _reverse_hf_permute_part(
+        self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None,
+    ) -> Tensor:
+        r = weights.shape[0] // 3
+        return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv)
+
+    def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
+        r = weights.shape[0] // 3
+        return weights[r * n_part:r * n_part + r, ...]
+
 
 class FalconModel(Model):
     def set_gguf_parameters(self):
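The relocated _reverse_hf_permute undoes a rotary-interleaving permutation applied on the Hugging Face side. A self-contained round-trip check; the hf_permute below is written from that assumption for illustration and is not taken from the diff:

import torch

def hf_permute(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # assumed forward layout change: interleave the two rotary halves per head
    r = w.shape[0] // n_head // 2
    return w.reshape(n_head, r, 2, *w.shape[1:]).swapaxes(1, 2).reshape(w.shape)

def reverse_hf_permute(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # body as in the diff: split each head back into its two halves
    return (
        w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:])
        .swapaxes(1, 2)
        .reshape(w.shape)
    )

w = torch.arange(12 * 4, dtype=torch.float32).reshape(12, 4)  # 2 heads, head_dim 6
assert torch.equal(reverse_hf_permute(hf_permute(w, 2), 2), w)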
@@ -630,7 +625,7 @@ class FalconModel(Model):
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
+            if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             # QKV tensor transform
@@ -655,7 +650,7 @@ class FalconModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
 
             n_dims = len(data.shape)
@@ -673,7 +668,7 @@ class FalconModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
@@ -704,7 +699,7 @@ class RefactModel(Model):
         block_count = self.hparams["n_layer"]
 
         self.gguf_writer.add_name("Refact")
         # refact uses Alibi. So this is from config.json which might be used by training.
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -730,31 +725,23 @@ class RefactModel(Model):
 
         tensors = dict(self.get_tensors())
         for i in range(block_count):
-            if f"transformer.h.{i}.attn.kv.weight" in tensors:
-                weight = tensors[f"transformer.h.{i}.attn.kv.weight"]
-                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = weight[
-                    : n_head_kv * head_dim
-                ]
-                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = weight[
-                    n_head_kv * head_dim:
-                ]
+            if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None:
+                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim]
+                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:]
                 del tensors[f"transformer.h.{i}.attn.kv.weight"]
-            if f"transformer.h.{i}.attn.q.weight" in tensors:
-                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = tensors[
-                    f"transformer.h.{i}.attn.q.weight"
-                ]
+            if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None:
+                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w
                 del tensors[f"transformer.h.{i}.attn.q.weight"]
-            if f"transformer.h.{i}.mlp.gate_up_proj.weight" in tensors:
-                weight = tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
-                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = weight[:ff_dim]
-                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = weight[ff_dim:]
+            if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None:
+                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
+                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
                 del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
 
         for name, data_torch in tensors.items():
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype != torch.float16 and data_torch.dtype != torch.float32:
+            if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
             data = data_torch.squeeze().numpy()
@@ -762,7 +749,7 @@ class RefactModel(Model):
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight",))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
 
             n_dims = len(data.shape)
@@ -780,7 +767,7 @@ class RefactModel(Model):
             if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                 data = data.astype(np.float16)
 
-            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
 
             self.gguf_writer.add_tensor(new_name, data)
@@ -820,8 +807,8 @@ class PersimmonModel(Model):
             data = data_torch.to(torch.float32).squeeze().numpy()
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
-                print("Can not map tensor '" + name + "'")
+                print(f"Can not map tensor {name!r}")
                 sys.exit()
             n_dims = len(data.shape)
-            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)