update convert.py

This commit is contained in:
slaren 2024-03-30 23:49:41 +01:00
parent 2abb6c7225
commit 6203d72651
4 changed files with 70 additions and 35 deletions

View file

@ -828,6 +828,15 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor:
def load() -> Tensor:
tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors]
return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors]))
s = lazy_tensors[0].shape.copy()
s.insert(0, len(lazy_tensors))
return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors))
# Functionality that simulates `torch.load` but where individual tensors are # Functionality that simulates `torch.load` but where individual tensors are
# only loaded into memory on demand, not all at once. # only loaded into memory on demand, not all at once.
# PyTorch can't do this natively as of time of writing: # PyTorch can't do this natively as of time of writing:
@ -1246,6 +1255,19 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
tmp = model tmp = model
# merge experts into one tensor
if params.n_experts > 0:
for l in range(params.n_layer):
for w in range(1, 4):
experts = []
for e in range(params.n_experts):
if f"layers.{l}.feed_forward.experts.{e}.w{w}.weight" in model:
experts.append(model[f"layers.{l}.feed_forward.experts.{e}.w{w}.weight"])
del tmp[f"layers.{l}.feed_forward.experts.{e}.w{w}.weight"]
else:
raise ValueError(f"Expert tensor not found: layers.{l}.feed_forward.experts.{e}.w{w}.weight")
tmp[f"layers.{l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
# HF models permut or pack some of the tensors, so we need to undo that # HF models permut or pack some of the tensors, so we need to undo that
for i in itertools.count(): for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" in model: if f"model.layers.{i}.self_attn.q_proj.weight" in model:

View file

@ -221,9 +221,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",

View file

@ -231,9 +231,9 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral #"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok #"transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok
), ),
# AWQ-activation gate # AWQ-activation gate
@ -252,9 +252,9 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral #"model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok #"transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok
), ),
# Feed-forward down # Feed-forward down
@ -280,10 +280,9 @@ class TensorNameMap:
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral #"model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
"transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok #"transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok
), ),
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (

View file

@ -426,9 +426,12 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_ACT, LLM_TENSOR_FFN_ACT,
LLM_TENSOR_FFN_DOWN_EXP, LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility
LLM_TENSOR_FFN_GATE_EXP, LLM_TENSOR_FFN_GATE_EXP,
LLM_TENSOR_FFN_UP_EXP, LLM_TENSOR_FFN_UP_EXP,
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_LAYER_OUT_NORM,
@ -463,6 +466,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
}, },
}, },
{ {
@ -4464,28 +4470,34 @@ static bool llm_load_tensors(
GGML_ASSERT(hparams.n_expert > 0); GGML_ASSERT(hparams.n_expert > 0);
GGML_ASSERT(hparams.n_expert_used > 0); GGML_ASSERT(hparams.n_expert_used > 0);
// hack to merge tensors, need to clean this up layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, hparams.n_expert}, false);
// merged tensors layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, hparams.n_expert}, false);
ggml_type type_gate = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type; layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, hparams.n_expert}, false);
ggml_type type_down = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
ggml_type type_up = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type;
layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, hparams.n_expert); if (layer.ffn_down_exps == nullptr) {
layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, hparams.n_expert); // hack to merge tensors, need to clean this up
layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, hparams.n_expert); // merged tensors
ggml_type type_gate = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
ggml_type type_down = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
ggml_type type_up = ml.get_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type;
// MoE branch layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, hparams.n_expert);
for (uint32_t x = 0; x < hparams.n_expert; ++x) { layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, hparams.n_expert);
// individual tensors as views layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, hparams.n_expert);
ggml_tensor * ffn_gate_exp = ggml_view_2d(ctx_split, layer.ffn_gate_exps, n_embd, n_ff, layer.ffn_gate_exps->nb[1], layer.ffn_gate_exps->nb[2]*x);
ggml_tensor * ffn_down_exp = ggml_view_2d(ctx_split, layer.ffn_down_exps, n_ff, n_embd, layer.ffn_down_exps->nb[1], layer.ffn_down_exps->nb[2]*x);
ggml_tensor * ffn_up_exp = ggml_view_2d(ctx_split, layer.ffn_up_exps, n_embd, n_ff, layer.ffn_up_exps->nb[1], layer.ffn_up_exps->nb[2]*x);
ggml_set_name(ffn_gate_exp, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x).c_str()); // MoE branch
ggml_set_name(ffn_down_exp, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x).c_str()); for (uint32_t x = 0; x < hparams.n_expert; ++x) {
ggml_set_name(ffn_up_exp, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x).c_str()); // individual tensors as views
ggml_tensor * ffn_gate_exp = ggml_view_2d(ctx_split, layer.ffn_gate_exps, n_embd, n_ff, layer.ffn_gate_exps->nb[1], layer.ffn_gate_exps->nb[2]*x);
ggml_tensor * ffn_down_exp = ggml_view_2d(ctx_split, layer.ffn_down_exps, n_ff, n_embd, layer.ffn_down_exps->nb[1], layer.ffn_down_exps->nb[2]*x);
ggml_tensor * ffn_up_exp = ggml_view_2d(ctx_split, layer.ffn_up_exps, n_embd, n_ff, layer.ffn_up_exps->nb[1], layer.ffn_up_exps->nb[2]*x);
ml.n_created += 3; // hack ggml_set_name(ffn_gate_exp, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x).c_str());
ggml_set_name(ffn_down_exp, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x).c_str());
ggml_set_name(ffn_up_exp, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x).c_str());
ml.n_created += 3; // hack
}
} }
} }
} }
@ -12933,7 +12945,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
// sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
// for getting the current layer as I initially thought, and we need to resort to parsing the // for getting the current layer as I initially thought, and we need to resort to parsing the
// tensor name. // tensor name.
n_layer /= n_expert;
// hack
//n_layer /= n_expert;
if (sscanf(name, "blk.%d.", &i_layer) != 1) { if (sscanf(name, "blk.%d.", &i_layer) != 1) {
throw std::runtime_error(format("Failed to determine layer for tensor %s", name)); throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
} }
@ -13412,8 +13426,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// This used to be a regex, but <regex> has an extreme cost to compile times. // This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// quantize only 2D tensors // quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) == 2); quantize &= (ggml_n_dims(tensor) >= 2);
quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= params->quantize_output_tensor || name != "output.weight";
quantize &= !params->only_copy; quantize &= !params->only_copy;