Compare commits
4 commits
master ... compilade/

| Author | SHA1 | Date |
|---|---|---|
| | 11f78c6a2d | |
| | 96a299ff60 | |
| | d703fa9fa5 | |
| | 93b9baee73 | |
1 changed file with 68 additions and 100 deletions
```diff
@@ -228,7 +228,7 @@ class Model:
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
         del name, new_name, bid, n_dims  # unused
```
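The base-class change above is the crux of this branch: `modify_tensors` becomes a generator instead of returning a list. Because the declared return type was already `Iterable[tuple[str, Tensor]]`, iterating call sites are unaffected; the subclass overrides below are rewritten to the `yield` form one by one. A minimal sketch (stand-in `Tensor` type and illustrative function names, not code from this diff) of why the two shapes are interchangeable to a consumer:

```python
from typing import Iterable

# Stand-in for torch.Tensor; the names below are illustrative only.
Tensor = list

def modify_as_list(name: str, t: Tensor) -> Iterable[tuple[str, Tensor]]:
    return [(name, t)]

def modify_as_generator(name: str, t: Tensor) -> Iterable[tuple[str, Tensor]]:
    yield name, t

# Both satisfy Iterable[tuple[str, Tensor]], so a consuming loop is unchanged:
for impl in (modify_as_list, modify_as_generator):
    for new_name, data in impl("tok_embd.weight", [1, 2, 3]):
        print(new_name, data)
```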
```diff
@@ -250,10 +250,6 @@ class Model:

             old_dtype = data_torch.dtype

-            # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
-                data_torch = data_torch.to(torch.float32)
-
             # use the first number-like part of the tensor name as the block id
             bid = None
             for part in name.split("."):
```
```diff
@@ -261,8 +257,13 @@ class Model:
                     bid = int(part)
                     break

-            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
-                data: np.ndarray = data  # type hint
+            for new_name, new_data_torch in self.modify_tensors(data_torch, name, bid):
+
+                # convert any unsupported-by-Numpy data types to float32
+                if new_data_torch.dtype not in (torch.float16, torch.float32):
+                    new_data_torch = new_data_torch.to(torch.float32)
+
+                data: np.ndarray = new_data_torch.squeeze().numpy()
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
                 data_qtype: gguf.GGMLQuantizationType | None = None
```
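With the float32 conversion moved inside the per-tensor loop (the two hunks above), each tensor yielded by `modify_tensors` is checked individually, since an override may emit dtypes that differ from its input. A small sketch of that check, assuming only that Numpy lacks bfloat16 support:

```python
import torch

def to_numpy_safe(t: torch.Tensor):
    # numpy has no bfloat16, so anything that isn't fp16/fp32 is upcast first
    if t.dtype not in (torch.float16, torch.float32):
        t = t.to(torch.float32)
    return t.squeeze().numpy()

print(to_numpy_safe(torch.zeros(2, dtype=torch.bfloat16)).dtype)  # float32
```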
```diff
@@ -736,8 +737,6 @@ class BloomModel(Model):

         name = re.sub(r'transformer\.', '', name)

-        tensors: list[tuple[str, Tensor]] = []
-
         if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
             # Map bloom-style qkv_linear to gpt-style qkv_linear
             # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
```
```diff
@@ -764,16 +763,14 @@ class BloomModel(Model):
             )
             logger.info("re-format attention.linear_qkv.bias")

-        tensors.append((self.map_tensor_name(name), data_torch))
+        yield self.map_tensor_name(name), data_torch

         if name == "word_embeddings.weight":
             assert self.tensor_names is not None

             # TODO: tie them at runtime, don't duplicate in the model file
             if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
-        return tensors
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch


 @Model.register("MPTForCausalLM")
```
```diff
@@ -818,7 +815,7 @@ class MPTModel(Model):
         else:
             new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias"))

-        return [(new_name, data_torch)]
+        yield new_name, data_torch


 @Model.register("OrionForCausalLM")
```
```diff
@@ -904,22 +901,16 @@ class BaichuanModel(Model):
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)

-        tensors: list[tuple[str, Tensor]] = []
-
         if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight":
             logger.info(f"Unpacking and permuting layer {bid}")
-            tensors = [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid),
-                    self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid),
-                    self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv)),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid),
-                    self._reverse_hf_part(data_torch, 2)),
-            ]
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid),
+                self._reverse_hf_permute_part(data_torch, 0, head_count, head_count))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid),
+                self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid),
+                self._reverse_hf_part(data_torch, 2))
         else:
-            tensors = [(self.map_tensor_name(name), data_torch)]
-
-        return tensors
+            yield (self.map_tensor_name(name), data_torch)

     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
```
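`W_pack` holds the Q, K and V projections packed into one weight. The helpers `_reverse_hf_permute_part` and `_reverse_hf_part` are not shown in this diff, so the sketch below is a hypothetical reading of the slicing they presumably perform (part 0/1/2 along dim 0), not the branch's actual implementation:

```python
import torch

def split_packed_qkv(w_pack: torch.Tensor, part: int) -> torch.Tensor:
    # Hypothetical stand-in for _reverse_hf_part: take part 0/1/2 (q/k/v)
    # of a weight stacked along dim 0.
    rows = w_pack.shape[0] // 3
    return w_pack[part * rows:(part + 1) * rows]

w = torch.arange(18.0).reshape(6, 3)  # toy packed weight, 3 parts of 2 rows
print(split_packed_qkv(w, 2))         # the V slice
```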
```diff
@@ -1035,7 +1026,7 @@ class XverseModel(Model):
         if name.endswith("k_proj.weight"):
             data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
```
```diff
@@ -1100,7 +1091,7 @@ class FalconModel(Model):
             v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
             data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("GPTBigCodeForCausalLM")
```
```diff
@@ -1168,22 +1159,20 @@ class RefactModel(Model):
         n_head_kv = 1
         head_dim = self.hparams["n_embd"] // n_head

-        tensors: list[tuple[str, Tensor]] = []
-
         if bid is not None:
             if name == f"transformer.h.{bid}.attn.kv.weight":
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:n_head_kv * head_dim]))
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[n_head_kv * head_dim:]))
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:n_head_kv * head_dim]
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[n_head_kv * head_dim:]
+                return
             elif name == f"transformer.h.{bid}.attn.q.weight":
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch))
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch
+                return
             elif name == f"transformer.h.{bid}.mlp.gate_up_proj.weight":
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]))
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]))
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]
+                return

-        if len(tensors) == 0:
-            tensors.append((self.map_tensor_name(name), data_torch))
-
-        return tensors
+        yield self.map_tensor_name(name), data_torch


 @Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
```
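The bare `return` statements introduced here (and in several hunks below) rely on generator semantics: returning from a generator simply ends iteration, so it stands in for both the old early-exit `return tensors` and `return []`. A two-function demonstration:

```python
from typing import Iterator

def gen(stop_early: bool) -> Iterator[int]:
    if stop_early:
        return  # in a generator this ends iteration, like returning []
    yield 1

print(list(gen(True)), list(gen(False)))  # [] [1]
```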
```diff
@@ -1230,9 +1219,8 @@ class StableLMModel(Model):
             self._q_norms[bid][name] = data_torch

             if len(self._q_norms[bid]) >= n_head:
-                return self._stack_qk_norm(bid, n_head, self._q_norms[bid], "q_layernorm")
-            else:
-                return []
+                yield self._stack_qk_norm(bid, n_head, self._q_norms[bid], "q_layernorm")
+            return

         if name.find("k_layernorm.norms") != -1:
             assert bid is not None
```
```diff
@@ -1243,13 +1231,12 @@ class StableLMModel(Model):
             self._k_norms[bid][name] = data_torch

             if len(self._k_norms[bid]) >= n_kv_head:
-                return self._stack_qk_norm(bid, n_kv_head, self._k_norms[bid], "k_layernorm")
-            else:
-                return []
+                yield self._stack_qk_norm(bid, n_kv_head, self._k_norms[bid], "k_layernorm")
+            return

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

-    def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm"):
+    def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm") -> tuple[str, Tensor]:
         datas: list[Tensor] = []
         # extract the norms in order
         for xid in range(n_head):
```
```diff
@@ -1261,7 +1248,7 @@ class StableLMModel(Model):
         merged_name = f"model.layers.{bid}.self_attn.{layer_name}.weight"
         new_name = self.map_tensor_name(merged_name)

-        return [(new_name, data_torch)]
+        return (new_name, data_torch)

     def write_tensors(self):
         super().write_tensors()
```
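`_stack_qk_norm` now returns a single `(name, tensor)` pair, which the callers yield directly. The stacking itself (only partially visible in these hunks) collects the per-head norm vectors in order and stacks them into one tensor; a sketch with illustrative names and shapes:

```python
import torch

# Minimal sketch in the spirit of _stack_qk_norm; names/shapes are illustrative.
norms = {f"model.layers.0.self_attn.q_layernorm.norms.{i}.weight": torch.ones(4)
         for i in range(3)}
datas = [norms[f"model.layers.0.self_attn.q_layernorm.norms.{i}.weight"]
         for i in range(3)]            # extract the norms in order
stacked = torch.stack(datas, dim=0)
print(stacked.shape)                   # torch.Size([3, 4])
```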
```diff
@@ -1345,7 +1332,6 @@ class LlamaModel(Model):
             self._experts[bid][name] = data_torch

             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []

                 # merge the experts into a single 3d tensor
                 for wid in ["w1", "w2", "w3"]:
```
```diff
@@ -1362,12 +1348,10 @@ class LlamaModel(Model):

                     new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                    yield new_name, data_torch
+            return

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

     def write_tensors(self):
         super().write_tensors()
```
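The MoE classes (`LlamaModel` above, and `GrokModel`, `Qwen2MoeModel`, `ArcticModel` below) share a buffer-then-yield pattern: expert weights are accumulated per block, and only once all `n_experts * 3` have arrived are they stacked into a 3D tensor. A reduced sketch of the merge step, with illustrative names and shapes:

```python
import torch

# Buffer per-expert 2-D weights, then stack into one 3-D tensor to yield.
n_experts = 4
experts = {f"experts.{xid}.w1.weight": torch.randn(8, 16)
           for xid in range(n_experts)}

datas = [experts.pop(f"experts.{xid}.w1.weight") for xid in range(n_experts)]
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([4, 8, 16])
```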
```diff
@@ -1408,7 +1392,6 @@ class GrokModel(Model):
             self._experts[bid][name] = data_torch

             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []

                 # merge the experts into a single 3d tensor
                 for wid in ["linear", "linear_1", "linear_v"]:
```
```diff
@@ -1425,12 +1408,10 @@ class GrokModel(Model):

                     new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                    yield new_name, data_torch
+            return

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("DbrxForCausalLM")
```
```diff
@@ -1496,7 +1477,7 @@ class DbrxModel(Model):
         # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
         new_name = self.map_tensor_name(name if not experts else name + ".weight", try_suffixes=(".weight",))

-        return [(new_name, data_torch)]
+        yield new_name, data_torch

     def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
         del name, new_name, bid  # unused
```
```diff
@@ -1546,7 +1527,7 @@ class MiniCPMModel(Model):
         if name.endswith(("k_proj.weight")):
             data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("QWenLMHeadModel")
```
```diff
@@ -1626,7 +1607,6 @@ class Qwen2MoeModel(Model):
             self._experts[bid][name] = data_torch

             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []

                 # merge the experts into a single 3d tensor
                 for w_name in ["down_proj", "gate_proj", "up_proj"]:
```
```diff
@@ -1643,12 +1623,10 @@ class Qwen2MoeModel(Model):

                     new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                    yield new_name, data_torch
+            return

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

     def write_tensors(self):
         super().write_tensors()
```
```diff
@@ -1677,24 +1655,20 @@ class GPT2Model(Model):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        tensors: list[tuple[str, Tensor]] = []
-
         # we don't need these
         if name.endswith((".attn.bias", ".attn.masked_bias")):
-            return tensors
+            return

         if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")):
             data_torch = data_torch.transpose(1, 0)

         new_name = self.map_tensor_name(name)

-        tensors.append((new_name, data_torch))
+        yield new_name, data_torch

         # note: GPT2 output is tied to (same as) wte in original model
         if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
-        return tensors
+            yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch


 @Model.register("PhiForCausalLM")
```
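As with `BloomModel` and `CodeShellModel`, the tied-embedding case becomes a conditional second `yield` of the same tensor rather than an extra list append. A minimal sketch of that fan-out pattern (illustrative names, with bytes standing in for tensors):

```python
from typing import Iterator

# One input tensor can fan out into two yielded entries.
def modify(name: str, data: bytes) -> Iterator[tuple[str, bytes]]:
    yield name, data
    if name == "token_embd.weight":  # output is tied to the embedding
        yield "output.weight", data

print([n for n, _ in modify("token_embd.weight", b"")])
```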
```diff
@@ -1922,7 +1896,7 @@ class PlamoModel(Model):
         elif new_name.endswith("attn_output.weight"):
             data_torch = self.shuffle_attn_output_weight(data_torch)

-        return [(new_name, data_torch)]
+        yield new_name, data_torch


 @Model.register("CodeShellForCausalLM")
```
```diff
@@ -1950,16 +1924,14 @@ class CodeShellModel(Model):

         new_name = self.map_tensor_name(name)

-        tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]
+        yield new_name, data_torch

         if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
             assert self.tensor_names is not None

             if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
                 # copy tok_embd.weight to output.weight
-                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
-
-        return tensors
+                yield self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch


 @Model.register("InternLM2ForCausalLM")
```
```diff
@@ -2100,13 +2072,11 @@ in chat mode so that the conversation can end normally.")
             k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads)
             # v = rearrange(v, " o g n i -> o (g n i)").T
             v = v.reshape((v.shape[0], -1)).T
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v),
-            ]
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v)
         else:
-            return [(self.map_tensor_name(name), data_torch)]
+            yield (self.map_tensor_name(name), data_torch)


 @Model.register("BertModel", "CamembertModel")
```
```diff
@@ -2176,9 +2146,9 @@ class BertModel(Model):

         # we are only using BERT for embeddings so we don't need the pooling layer
         if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
-            return []  # we don't need these
+            return  # we don't need these

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("NomicBertModel")
```
```diff
@@ -2251,13 +2221,13 @@ class GemmaModel(Model):
         # To prevent errors, skip loading lm_head.weight.
         if name == "lm_head.weight":
             logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+            return

         # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
         if name.endswith("norm.weight"):
             data_torch = data_torch + 1

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("Starcoder2ForCausalLM")
```
```diff
@@ -2363,11 +2333,12 @@ class MambaModel(Model):
         if self._tok_embd is not None and new_name == output_name:
             if torch.equal(self._tok_embd, data_torch):
                 logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting")
-                return []
+                self._tok_embd = None
+                return
         elif new_name == tok_embd_name:
             self._tok_embd = data_torch

-        return [(new_name, data_torch)]
+        yield new_name, data_torch

     def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
         del n_dims  # unused
```
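`MambaModel` additionally clears `self._tok_embd` before the bare `return`, apparently so the reference is dropped once the duplicate `output.weight` has been detected and omitted. A sketch of the equality check that drives the omission, with illustrative tensors:

```python
import torch

# Skip writing output.weight when it is identical to the token embedding.
tok_embd = torch.arange(6.0).reshape(2, 3)
output_w = tok_embd.clone()

if torch.equal(tok_embd, output_w):
    print("output.weight is equivalent to token_embd.weight, omitting")
```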
```diff
@@ -2425,7 +2396,7 @@ class OlmoModel(Model):
         if name.endswith("k_proj.weight"):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch


 @Model.register("JinaBertModel", "JinaBertForMaskedLM")
```
```diff
@@ -2582,7 +2553,6 @@ class ArcticModel(Model):
             self._experts[bid][name] = data_torch

             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []

                 # merge the experts into a single 3d tensor
                 for wid in ["w1", "w2", "w3"]:
```
```diff
@@ -2599,12 +2569,10 @@ class ArcticModel(Model):

                     new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                    yield new_name, data_torch
+            return

-        return [(self.map_tensor_name(name), data_torch)]
+        yield self.map_tensor_name(name), data_torch

     def write_tensors(self):
         super().write_tensors()
```