llama: dbrx: document changes, permute only FFN_DOWN_EXPS. Add a check for ftype

This commit is contained in:
Pierrick HYMBERT 2024-04-08 20:08:54 +02:00
parent 9968952921
commit e66f1e3448

View file

@ -1522,16 +1522,21 @@ class DbrxModel(Model):
n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
n_embd = self.hparams["d_model"]
# Specific behavior for experts tensors: suffix .weight, reshape to 3D and transpose
# orginal implementation expects (n_expert, n_ff, n_embd)
exp_tensor_names = {"ffn.experts.mlp.v1": (0, 1, 2), # LLM_TENSOR_FFN_GATE_EXPS(n_embd, n_ff, n_expert)
"ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS(n_ff, n_embd, n_expert)
"ffn.experts.mlp.w1": (0, 1, 2)} # LLM_TENSOR_FFN_UP_EXPS (n_embd, n_ff, n_expert)
# Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
# original implementation expects (n_expert, n_ff, n_embd) for all experts weights
# But llama.cpp moe graph works differently
# AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
# so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
exp_tensor_names = {"ffn.experts.mlp.v1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
"ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert}
"ffn.experts.mlp.w1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
experts = False
for exp_tensor_name in exp_tensor_names.keys():
if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
experts = True
data_torch = data_torch.view(n_expert, n_ff, n_embd).permute(*exp_tensor_names[exp_tensor_name])
data_torch = data_torch.view(n_expert, n_ff, n_embd)
if permute_tensor := exp_tensor_names[exp_tensor_name] is not None:
data_torch = data_torch.permute(*permute_tensor)
break
old_dtype = data_torch.dtype
@ -1556,6 +1561,12 @@ class DbrxModel(Model):
n_dims = len(data.shape)
data_dtype = data.dtype
# Most of the codebase that takes in 1D tensors only handles F32 tensors
# and most of the outputs tensors are F32.
if data_dtype != np.float32 and n_dims == 1:
print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
sys.exit()
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)