Merge branch 'master' into compilade/mamba2
This commit is contained in:
commit
0e601cafe9
197 changed files with 25587 additions and 16246 deletions
|
@ -3,6 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import argparse
|
||||
import contextlib
|
||||
|
@ -63,6 +64,7 @@ class Model:
|
|||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
dir_model_card: Path
|
||||
is_lora: bool
|
||||
|
||||
# subclasses should define this!
|
||||
model_arch: gguf.MODEL_ARCH
|
||||
|
@ -70,7 +72,7 @@ class Model:
|
|||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False):
|
||||
if type(self) is Model:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
|
@ -92,6 +94,7 @@ class Model:
|
|||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py
|
||||
|
||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||
|
@ -129,12 +132,14 @@ class Model:
|
|||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
|
||||
if len(self.part_names) > 1:
|
||||
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
|
||||
index_name += ".index.json"
|
||||
index_file = self.dir_model / index_name
|
||||
|
||||
if index_file.is_file():
|
||||
self.tensor_names = set()
|
||||
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
|
||||
index_name += ".index.json"
|
||||
logger.info(f"gguf: loading model weight map from '{index_name}'")
|
||||
with open(self.dir_model / index_name, "r", encoding="utf-8") as f:
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index: dict[str, Any] = json.load(f)
|
||||
weight_map = index.get("weight_map")
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
|
@ -142,6 +147,7 @@ class Model:
|
|||
self.tensor_names.update(weight_map.keys())
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
|
||||
for part_name in self.part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
|
@ -168,9 +174,17 @@ class Model:
|
|||
data = LazyTorchTensor.from_eager(data)
|
||||
yield name, data
|
||||
|
||||
# only verify tensor name presence; it doesn't matter if they are not in the right files
|
||||
if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}")
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
|
@ -296,12 +310,31 @@ class Model:
|
|||
gguf.MODEL_TENSOR.POS_EMBD,
|
||||
gguf.MODEL_TENSOR.TOKEN_TYPES,
|
||||
gguf.MODEL_TENSOR.SSM_CONV1D,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_FIRST,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_W1,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_W2,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
|
||||
)
|
||||
)
|
||||
or not name.endswith(".weight")
|
||||
or not new_name.endswith(".weight")
|
||||
):
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
if data_qtype is False and any(
|
||||
self.match_model_tensor_name(new_name, key, bid)
|
||||
for key in (
|
||||
gguf.MODEL_TENSOR.TOKEN_EMBD,
|
||||
gguf.MODEL_TENSOR.OUTPUT,
|
||||
)
|
||||
):
|
||||
if self.ftype in (
|
||||
gguf.LlamaFileType.MOSTLY_TQ1_0,
|
||||
gguf.LlamaFileType.MOSTLY_TQ2_0,
|
||||
):
|
||||
# TODO: use Q4_K and Q6_K
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
|
||||
# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
|
||||
if isinstance(data_qtype, bool):
|
||||
if self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
|
@ -312,6 +345,10 @@ class Model:
|
|||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.TQ1_0
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.TQ2_0
|
||||
else:
|
||||
raise ValueError(f"Unknown file type: {self.ftype.name}")
|
||||
|
||||
|
@ -600,6 +637,9 @@ class Model:
|
|||
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
|
||||
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
|
||||
res = "exaone"
|
||||
if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
|
||||
# ref: https://huggingface.co/microsoft/phi-2
|
||||
res = "phi-2"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
@ -1458,7 +1498,7 @@ class StableLMModel(Model):
|
|||
raise ValueError(f"Unprocessed norms: {norms}")
|
||||
|
||||
|
||||
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
||||
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
||||
class LlamaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
|
@ -1570,7 +1610,7 @@ class LlamaModel(Model):
|
|||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
|
||||
factor = rope_scaling.get("factor", 8.0)
|
||||
|
@ -1593,7 +1633,8 @@ class LlamaModel(Model):
|
|||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
|
||||
super().prepare_tensors()
|
||||
|
||||
|
@ -1616,15 +1657,16 @@ class BitnetModel(Model):
|
|||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
def weight_quant(self, weight):
|
||||
def weight_quant(self, weight: Tensor) -> Tensor:
|
||||
dtype = weight.dtype
|
||||
weight = weight.float()
|
||||
s = 1 / weight.abs().mean().clamp(min=1e-5)
|
||||
weight = (weight * s).round().clamp(-1, 1) / s
|
||||
scale = weight.abs().max().unsqueeze(0)
|
||||
weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
|
||||
weight = torch.sign(weight).type(dtype)
|
||||
return weight.type(dtype), scale.type(torch.float32)
|
||||
scale = weight.abs().mean().clamp(min=1e-5)
|
||||
iscale = 1 / scale
|
||||
# TODO: multiply by the scale directly instead of inverting it twice
|
||||
# (this is also unnecessarily doubly inverted upstream)
|
||||
# ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
|
||||
result = (weight * iscale).round().clamp(-1, 1) / iscale
|
||||
return result.type(dtype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
@ -1639,11 +1681,9 @@ class BitnetModel(Model):
|
|||
gguf.MODEL_TENSOR.FFN_GATE,
|
||||
]):
|
||||
# transform weight into 1/0/-1 (in fp32)
|
||||
weight_torch, scale_torch = self.weight_quant(data_torch)
|
||||
yield (new_name, weight_torch)
|
||||
yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
|
||||
else:
|
||||
yield (new_name, data_torch)
|
||||
data_torch = self.weight_quant(data_torch)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("GrokForCausalLM")
|
||||
|
@ -1812,6 +1852,60 @@ class MiniCPMModel(Model):
|
|||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("MiniCPM3ForCausalLM")
|
||||
class MiniCPM3Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINICPM3
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
|
||||
rope_dims = hparams["qk_rope_head_dim"]
|
||||
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
|
||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
return
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
|
||||
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
n_head //= n_kv_head
|
||||
|
||||
return (
|
||||
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape)
|
||||
)
|
||||
|
||||
|
||||
@Model.register("QWenLMHeadModel")
|
||||
class QwenModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN
|
||||
|
@ -2140,8 +2234,9 @@ class Phi3MiniModel(Model):
|
|||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
|
@ -2712,6 +2807,86 @@ class StarCoder2Model(Model):
|
|||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
||||
|
||||
@Model.register("Rwkv6ForCausalLM")
|
||||
class Rwkv6Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV6
|
||||
|
||||
def set_vocab(self):
|
||||
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
|
||||
vocab_size = self.hparams.get("vocab_size", 65536)
|
||||
|
||||
tokens: list[bytes] = ['<s>'.encode("utf-8")]
|
||||
toktypes: list[int] = [gguf.TokenType.CONTROL]
|
||||
|
||||
with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
parts = line.split(' ')
|
||||
assert len(parts) >= 3
|
||||
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
|
||||
token = token.encode("utf-8") if isinstance(token, str) else token
|
||||
assert isinstance(token, bytes)
|
||||
assert len(token) == token_len
|
||||
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
|
||||
tokens.append(token_text.encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
remainder = vocab_size - len(tokens)
|
||||
assert remainder >= 0
|
||||
for i in range(len(tokens), vocab_size):
|
||||
tokens.append(f"[PAD{i}]".encode("utf-8"))
|
||||
toktypes.append(gguf.TokenType.UNUSED)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("rwkv")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_size = self.hparams["head_size"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
rescale_every_n_layers = self.hparams["rescale_every"]
|
||||
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
|
||||
time_mix_extra_dim = 64 if hidden_size == 4096 else 32
|
||||
time_decay_extra_dim = 128 if hidden_size == 4096 else 64
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
|
||||
new_name += ".weight"
|
||||
|
||||
if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
|
||||
data_torch = data_torch.transpose(0, 1)
|
||||
|
||||
if new_name.endswith("time_mix_w2.weight"):
|
||||
data_torch = data_torch.permute(0, 2, 1)
|
||||
|
||||
rescale_every_n_layers = self.hparams["rescale_every"]
|
||||
if rescale_every_n_layers > 0:
|
||||
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|
||||
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
@ -2905,6 +3080,66 @@ class OlmoModel(Model):
|
|||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("OlmoeForCausalLM")
|
||||
class OlmoeModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.OLMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("JinaBertModel", "JinaBertForMaskedLM")
|
||||
class JinaBertV2Model(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
|
||||
|
@ -3887,7 +4122,7 @@ class ExaoneModel(Model):
|
|||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
|
||||
factor = rope_scaling.get("factor", 8.0)
|
||||
|
@ -3910,11 +4145,42 @@ class ExaoneModel(Model):
|
|||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
if not self.is_lora:
|
||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||
|
||||
super().prepare_tensors()
|
||||
|
||||
|
||||
@Model.register("GraniteForCausalLM")
|
||||
class GraniteModel(LlamaModel):
|
||||
"""Conversion for IBM's GraniteForCausalLM"""
|
||||
model_arch = gguf.MODEL_ARCH.GRANITE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
"""Granite uses standard llama parameters with the following differences:
|
||||
|
||||
- No head_dim support
|
||||
- New multiplier params:
|
||||
- attention_scale
|
||||
- embedding_scale
|
||||
- residual_scale
|
||||
- logits_scaling
|
||||
"""
|
||||
if head_dim := self.hparams.pop("head_dim", None):
|
||||
logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
|
||||
super().set_gguf_parameters()
|
||||
# NOTE: Convert _multiplier params to _scale params for naming
|
||||
# consistency
|
||||
if attention_scale := self.hparams.get("attention_multiplier"):
|
||||
self.gguf_writer.add_attention_scale(attention_scale)
|
||||
if embedding_scale := self.hparams.get("embedding_multiplier"):
|
||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||
if residual_scale := self.hparams.get("residual_multiplier"):
|
||||
self.gguf_writer.add_residual_scale(residual_scale)
|
||||
if logits_scaling := self.hparams.get("logits_scaling"):
|
||||
self.gguf_writer.add_logit_scale(logits_scaling)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
# tree of lazy tensors
|
||||
|
@ -3995,8 +4261,8 @@ def parse_args() -> argparse.Namespace:
|
|||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
|
@ -4083,6 +4349,8 @@ def main() -> None:
|
|||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
"tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
|
||||
"tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
|
||||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue