llama: quantize: remove the incorrect lookup of the qkv tensor name, which was missing the .weight suffix

model: dbrx: convert to gguf: force expert tensors to have the .weight suffix
Pierrick HYMBERT 2024-04-07 17:55:22 +02:00
parent 2449ef48a9
commit 1bd94270e5
2 changed files with 67 additions and 13 deletions
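
Background for the change below: as the comment in the converter explains, quantization only treats tensors whose names end in .weight as quantizable weights, but the DBRX checkpoint stores its expert FFN tensors without that suffix, so the converter now appends it before mapping the names. A minimal standalone sketch of that renaming rule; the expert-tensor substrings are taken from the diff, while the example input name is only illustrative:

    # Sketch of the suffix-forcing rule applied in DbrxModel.write_tensors() below.
    EXP_TENSOR_NAMES = ["ffn.experts.mlp.v1", "ffn.experts.mlp.w1", "ffn.experts.mlp.w2"]

    def force_weight_suffix(name: str) -> str:
        # expert FFN tensors must end in ".weight" so that quantization treats
        # them like every other weight tensor
        if any(part in name for part in EXP_TENSOR_NAMES) and ".weight" not in name:
            return name + ".weight"
        return name

    print(force_weight_suffix("transformer.blocks.0.ffn.experts.mlp.w1"))
    # transformer.blocks.0.ffn.experts.mlp.w1.weight
    print(force_weight_suffix("transformer.blocks.0.ffn.experts.mlp.w1.weight"))
    # unchanged: transformer.blocks.0.ffn.experts.mlp.w1.weight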

convert-hf-to-gguf.py

@@ -95,7 +95,7 @@ class Model(ABC):
             self.gguf_writer.add_context_length(n_ctx)
             print(f"gguf: context length = {n_ctx}")
 
-        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
         self.gguf_writer.add_embedding_length(n_embd)
         print(f"gguf: embedding length = {n_embd}")
@@ -103,7 +103,7 @@ class Model(ABC):
             self.gguf_writer.add_feed_forward_length(n_ff)
             print(f"gguf: feed forward length = {n_ff}")
 
-        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
         self.gguf_writer.add_head_count(n_head)
         print(f"gguf: head count = {n_head}")
@@ -1489,23 +1489,77 @@ class DbrxModel(Model):
     model_arch = gguf.MODEL_ARCH.DBRX
 
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
         ffn_config = self.hparams["ffn_config"]
         attn_config = self.hparams["attn_config"]
         self.gguf_writer.add_name(self.hparams["model_type"])
+        self.gguf_writer.add_block_count(self.hparams["n_layers"])
 
         self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
-        self.gguf_writer.add_block_count(self.hparams["n_layers"])
+        self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])
 
         self.gguf_writer.add_head_count(self.hparams["n_heads"])
         self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
 
         self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
 
         self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])
 
         self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
         self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
+
+        self.gguf_writer.add_file_type(self.ftype)
+        print(f"gguf: file type = {self.ftype}")
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers")
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # In MoE models the ffn tensors are typically most of the model weights,
+            # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
+            # Every other model has the weight names ending in .weight;
+            # let's assume that is the convention, which is not the case for dbrx:
+            # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
+            exp_tensor_names = ["ffn.experts.mlp.v1", "ffn.experts.mlp.w1", "ffn.experts.mlp.w2"]
+            for exp_tensor_name in exp_tensor_names:
+                if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
+                    name += ".weight"
+                    break
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
     def set_vocab(self):
         self._set_vocab_tiktoken()
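
The dtype handling in write_tensors() above follows the usual convert-hf-to-gguf.py convention, with ftype 0 meaning f32 output and ftype 1 meaning f16 output. As a reading aid, the same rules restated as a small pure function; this is a paraphrase of the code above, not shared code:

    import numpy as np

    def choose_storage_dtype(data: np.ndarray, name: str, ftype: int) -> np.ndarray:
        # Paraphrase of the dtype rules in DbrxModel.write_tensors() above.
        n_dims = len(data.shape)
        if ftype == 0 and data.dtype == np.float16:
            # f32 output: upcast float16 data
            return data.astype(np.float32)
        if ftype == 1 and data.dtype == np.float16 and n_dims == 1:
            # f16 output: 1-D float16 tensors are still stored as float32
            return data.astype(np.float32)
        if ftype == 1 and data.dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            # f16 output: 2-D float32 ".weight" matrices are downcast to float16
            return data.astype(np.float16)
        return data

Note that the 3-D expert tensors fall through unchanged here; it is the later quantization step that depends on the forced .weight suffix, as the comment in write_tensors() explains.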

llama.cpp

@@ -4692,9 +4692,9 @@ static bool llm_load_tensors(
                     layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd});
 
                     layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
-                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, i), {n_embd, n_ff, n_expert});
-                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, i), {n_ff, n_embd, n_expert});
-                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   i), {n_embd, n_ff, n_expert});
+                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
+                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
+                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
 
                     layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                 }
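
One way to sanity-check a converted file is to confirm that the expert tensors carry the names and shapes the loader above expects: 3-D tensors whose GGUF names end in .weight, created via tn(..., "weight", i). A rough sketch, assuming gguf-py's GGUFReader API and an illustrative output path:

    from gguf import GGUFReader  # gguf-py

    reader = GGUFReader("dbrx-instruct-f16.gguf")  # illustrative path

    for tensor in reader.tensors:
        if "_exps" in tensor.name:
            # the DBRX branch of llm_load_tensors() above expects these as 3-D
            # tensors whose names end in ".weight"
            assert tensor.name.endswith(".weight"), tensor.name
            print(tensor.name, list(tensor.shape))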