gguf : avoid copy-pasted tensor names

This commit is contained in:
Cebtenzzre 2023-09-29 18:44:55 -04:00
parent 40e07a60f9
commit ea90d2aa8c

View file

@ -118,79 +118,103 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER: "starcoder", MODEL_ARCH.STARCODER: "starcoder",
} }
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_ARCH.LLAMA: { MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
}, MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_ARCH.GPTNEOX: { MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.TOKEN_EMBD: "token_embd", }
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.ATTN_NORM,
}, MODEL_TENSOR.ATTN_Q,
MODEL_ARCH.FALCON: { MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", ],
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_ARCH.GPTNEOX: [
}, MODEL_TENSOR.TOKEN_EMBD,
MODEL_ARCH.BAICHUAN: { MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", ],
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_ARCH.FALCON: [
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.ATTN_NORM_2,
}, MODEL_TENSOR.ATTN_QKV,
MODEL_ARCH.STARCODER: { MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.OUTPUT_NORM: "output_norm", ],
MODEL_TENSOR.OUTPUT: "output", MODEL_ARCH.BAICHUAN: [
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.ATTN_Q,
}, MODEL_TENSOR.ATTN_K,
MODEL_ARCH.GPT2: { MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.STARCODER: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [
# TODO # TODO
}, ],
# TODO # TODO
} }
# Per-architecture tensor-name map, derived from the single shared TENSOR_NAMES
# table so names are no longer copy-pasted per architecture.
# Keyed by architecture, then by tensor, giving that tensor's name template
# (e.g. "blk.{bid}.attn_norm").
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
    # NOTE: the original wrote `{ {t: ...} for m, ts in ... }` — a set
    # comprehension over dicts, which raises "TypeError: unhashable type: 'dict'"
    # at import time and loses the arch key.  A dict comprehension keyed by `m`
    # matches the declared annotation.
    m: {t: TENSOR_NAMES[t] for t in ts} for m, ts in MODEL_TENSORS.items()
}
# tensors that will not be serialized # tensors that will not be serialized
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [