convert-new.py : tensor name mapping

Georgi Gerganov 2023-08-17 13:15:17 +03:00
parent e970845383
commit 86bc9d2750
2 changed files with 205 additions and 113 deletions
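In short: this commit drops convert-new.py's hard-coded LLaMA tensor list (make_tensors_list / convert_transformers_to_orig) and instead renames tensors through the architecture-aware name map added to gguf.py. Below is an illustrative sketch of that renaming pass, not part of the commit; the small dict stands in for the output of gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, n_layer), with names taken from the diff that follows.

# Illustrative sketch only -- the dict below stands in for
# gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, n_layer).
tmap = {
    "model.embed_tokens":              "tok_embd",      # llama-hf token embeddings
    "model.layers.0.self_attn.q_proj": "blk.0.attn_q",  # llama-hf attention query
    "layers.0.feed_forward.w3":        "blk.0.ffn_up",  # llama-pth feed-forward up
}

def rename(name: str) -> str:
    # Same suffix handling as convert_model_names() in the diff below.
    if name in tmap:
        return tmap[name]
    if name.endswith(".weight") and name[:-7] in tmap:
        return tmap[name[:-7]] + ".weight"
    if name.endswith(".bias") and name[:-5] in tmap:
        return tmap[name[:-5]] + ".bias"
    raise Exception(f"Unexpected tensor name: {name}")

print(rename("model.embed_tokens.weight"))               # -> tok_embd.weight
print(rename("model.layers.0.self_attn.q_proj.weight"))  # -> blk.0.attn_q.weight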

convert-new.py

@@ -45,14 +45,6 @@ DT_BF16 = UnquantizedDataType('BF16')
DataType = Union[UnquantizedDataType]
DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
DT_F32: 0,
DT_F16: 1,
}
FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
{ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
DT_BF16: np.dtype(np.uint16),
DT_F16: np.dtype(np.float16),
@@ -78,31 +70,6 @@ class GGMLFileType(enum.Enum):
else:
raise ValueError(self)
# TODO: this is LLaMA specific
def make_tensors_list() -> List[str]:
ret = [
'tok_embeddings.weight',
'norm.weight',
'output.weight',
]
for i in range(80): # maximum number of layer
ret += [
f'layers.{i}.attention.wq.weight',
f'layers.{i}.attention.wk.weight',
f'layers.{i}.attention.wv.weight',
f'layers.{i}.attention.wo.weight',
f'layers.{i}.attention_norm.weight',
f'layers.{i}.feed_forward.w1.weight',
f'layers.{i}.feed_forward.w2.weight',
f'layers.{i}.feed_forward.w3.weight',
f'layers.{i}.ffn_norm.weight',
]
return ret
# TODO: this should be generalized for non-LLaMA models
TENSORS_LIST = make_tensors_list()
TENSORS_SET = set(TENSORS_LIST)
def find_n_mult(n_ff: int, n_embd: int) -> int:
# hardcoded magic range
for n_mult in range(8192, 1, -1):
@@ -533,34 +500,6 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
s[0] = s[0] // 3
return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
out: LazyModel = {}
out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
out["norm.weight"] = model["model.norm.weight"]
out["output.weight"] = model["lm_head.weight"]
for i in itertools.count():
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head_kv)
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
else:
break
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
out[f"layers.{i}.feed_forward.w2.weight"] = model[f"model.layers.{i}.mlp.down_proj.weight"]
out[f"layers.{i}.feed_forward.w3.weight"] = model[f"model.layers.{i}.mlp.up_proj.weight"]
out[f"layers.{i}.attention_norm.weight"] = model[f"model.layers.{i}.input_layernorm.weight"]
out[f"layers.{i}.ffn_norm.weight"] = model[f"model.layers.{i}.post_attention_layernorm.weight"]
return out
# Functionality that simulates `torch.load` but where individual tensors are
# only loaded into memory on demand, not all at once.
@@ -750,8 +689,8 @@ class OutputFile:
def __init__(self, fname_out: Path) -> None:
self.gguf = gguf.GGUFWriter.open(fname_out)
def add_meta_arch(self, params: Params, file_type: GGMLFileType) -> None:
def add_meta_arch(self, params: Params) -> None:
llm_arch = "llama"
llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
self.gguf.add_architecture (llm_arch)
self.gguf.add_context_length (llm_arch, params.n_ctx)
@@ -763,13 +702,6 @@ class OutputFile:
self.gguf.add_head_count_kv (llm_arch, params.n_head_kv)
self.gguf.add_layer_norm_rms_eps (llm_arch, params.f_norm_eps)
#def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
# sname = name.encode('utf-8')
# self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type]))
# self.fout.write(struct.pack("i" * len(shape), *shape[::-1]))
# self.fout.write(sname)
# self.fout.seek((self.fout.tell() + 31) & -32)
def add_meta_vocab(self, vocab: Vocab) -> None:
tokens = []
scores = []
@@ -794,17 +726,17 @@ class OutputFile:
@staticmethod
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
of = OutputFile(fname_out)
of.add_meta_arch(params, file_type=GGMLFileType.AllF32)
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
of.write_meta()
of.close()
@staticmethod
def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
of.add_meta_arch(params, file_type)
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
@@ -822,21 +754,39 @@ class OutputFile:
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
wq_type = model["layers.0.attention.wq.weight"].data_type
wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
return GGMLFileType.AllF32
if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
return GGMLFileType.MostlyF16
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
raise Exception(f"Unexpected combination of types: {name_to_type}")
def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
if "lm_head.weight" in model:
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
model = convert_transformers_to_orig(model, params)
model = filter_and_sort_tensors(model)
return model
out: LazyModel = {}
for name, lazy_tensor in model.items():
name_new = name
if name in tmap:
name_new = tmap[name]
elif name.endswith(".weight") and name[:-7] in tmap:
name_new = tmap[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tmap:
name_new = tmap[name[:-5]] + ".bias"
else:
raise Exception(f"Unexpected tensor name: {name}")
out[name_new] = lazy_tensor
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
return out
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -893,11 +843,6 @@ def load_some_model(path: Path) -> ModelPlus:
# Try the PyTorch patterns too, with lower priority
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
files = [file for glob in globs for file in path.glob(glob)]
if not files:
# Try GGML too, but with lower priority, since if both a non-GGML
# model and a GGML model exist in the same directory, we assume the
# latter was converted from the former.
files = list(path.glob("ggml-model*.bin*"))
if not files:
raise Exception(f"Can't find model in directory {path}")
if len(files) > 1:
@@ -914,10 +859,6 @@ def load_some_model(path: Path) -> ModelPlus:
return model_plus
def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
return {name: model[name] for name in TENSORS_LIST if name in model}
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
@@ -937,8 +878,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence
raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file '{path}', type '{vocabtype}'") print(f"Loading vocab file '{path}', type '{vocabtype}'")
added_tokens_path = path.parent / "added_tokens.json"
if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm":
@@ -1018,12 +961,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab = load_vocab(vocab_dir, args.vocabtype)
model = model_plus.model
model = do_necessary_conversions(model, params) # TODO: utilize gguf.get_tensor_name_map
model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
output_type = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, output_type)
outfile = args.outfile or default_outfile(model_plus.paths, output_type)
OutputFile.write_all(outfile, params, output_type, model, vocab)
OutputFile.write_all(outfile, params, model, vocab)
print(f"Wrote {outfile}")

gguf.py

@@ -8,7 +8,7 @@ import sys
import struct
import numpy as np
from enum import IntEnum
from enum import IntEnum, auto
from typing import Any, IO, List
#
@@ -70,34 +70,146 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
# recommended mapping of model tensor names for storage in gguf
#
def get_tensor_name_map(n_blocks : int):
#LLAMA_TOKEN_EMBD = "token_embd"
#LLAMA_OUTPUT_NORM = "output_norm"
#LLAMA_OUTPUT = "output"
#LLAMA_ATTN_NORM = "blk.{bid}.attn_norm"
#LLAMA_ATTN_Q = "blk.{bid}.attn_q"
#LLAMA_ATTN_K = "blk.{bid}.attn_k"
#LLAMA_ATTN_V = "blk.{bid}.attn_v"
#LLAMA_ATTN_OUTPUT = "blk.{bid}.attn_output"
#LLAMA_FFN_NORM = "blk.{bid}.ffn_norm"
#LLAMA_FFN_GATE = "blk.{bid}.ffn_gate"
#LLAMA_FFN_DOWN = "blk.{bid}.ffn_down"
#LLAMA_FFN_UP = "blk.{bid}.ffn_up"
#
#GPT_POS_EMBD = "pos_embd"
#
#FALCON_ATTN_NORM_2 = "blk.{bid}.attn_norm_2"
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
POS_EMBD = auto()
OUTPUT = auto()
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ATTN_Q = auto()
ATTN_K = auto()
ATTN_V = auto()
ATTN_QKV = auto()
ATTN_OUT = auto()
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
FFN_NORM = auto()
MODEL_ARCH_NAMES = {
MODEL_ARCH.LLAMA : "llama",
MODEL_ARCH.FALCON : "falcon",
MODEL_ARCH.GPT2 : "gpt-2",
MODEL_ARCH.GPTJ : "gpt-j",
MODEL_ARCH.GPTNEOX : "gpt-neox",
MODEL_ARCH.MPT : "mpt",
}
MODEL_TENSOR_NAMES = {
MODEL_ARCH.LLAMA : {
MODEL_TENSOR.TOKEN_EMBD : "tok_embd",
MODEL_TENSOR.OUTPUT_NORM : "output_norm",
MODEL_TENSOR.OUTPUT : "output",
MODEL_TENSOR.ROPE_FREQS : "rope_freqs",
MODEL_TENSOR.ATTN_NORM : "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_Q : "blk.{bid}.attn_q",
MODEL_TENSOR.ATTN_K : "blk.{bid}.attn_k",
MODEL_TENSOR.ATTN_V : "blk.{bid}.attn_v",
MODEL_TENSOR.ATTN_OUT : "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD : "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.FFN_NORM : "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE : "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN : "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP : "blk.{bid}.ffn_up",
},
MODEL_ARCH.FALCON : {
MODEL_TENSOR.TOKEN_EMBD : "tok_embd",
MODEL_TENSOR.OUTPUT_NORM : "output_norm",
MODEL_TENSOR.OUTPUT : "output",
MODEL_TENSOR.ATTN_NORM : "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_NORM_2 : "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_QKV : "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT : "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_DOWN : "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP : "blk.{bid}.ffn_up",
},
MODEL_ARCH.GPT2 : {
# TODO
},
# TODO
}
# tensors that will not be serialized
MODEL_TENSOR_SKIP = {
MODEL_ARCH.LLAMA : {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
},
}
def get_tensor_name_map(arch : MODEL_ARCH, n_blocks : int) -> dict:
tensor_map = {}
# Token embeddings
mapped_to = "token_embd"
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None)
tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
tensor_map["transformer.word_embeddings"] = mapped_to # falcon
tensor_map["model.embed_tokens"] = mapped_to # llama-hf
tensor_map["tok_embeddings"] = mapped_to # llama-pth
# Position embeddings
mapped_to = "pos_embd"
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
tensor_map["transformer.wpe"] = mapped_to # gpt2
# Output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
tensor_map["embed_out"] = mapped_to # gptneox
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth
# Output norm
mapped_to = "output_norm"
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
tensor_map["transformer.norm_f"] = mapped_to # mpt
tensor_map["model.norm"] = mapped_to # llama-hf
tensor_map["norm"] = mapped_to # llama-pth
# Output
mapped_to = "output" # Rope frequencies
tensor_map["embed_out"] = mapped_to # gptneox mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth tensor_map["rope.freqs"] = mapped_to # llama-pth
# Attention and fee-forward layer blocks
# Attention and feed-forward blocks
for i in range(0,n_blocks):
# Attention norm
mapped_to = "blk."+str(i)+".attn_norm"
# TODO: is there are simpler way to write these 2 lines in Python?
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to else None
tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
@@ -105,56 +217,93 @@ def get_tensor_name_map(n_blocks : int):
tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
# Attention norm 2 # Attention norm 2
mapped_to = "blk."+str(i)+".attn_norm_2" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
# Attention query-key-value # Attention query-key-value
mapped_to = "blk."+str(i)+".attn_qkv" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
# Attention query # Attention query
mapped_to = "blk."+str(i)+".attn_q" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
# Attention key # Attention key
mapped_to = "blk."+str(i)+".attn_k" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
# Attention value # Attention value
mapped_to = "blk."+str(i)+".attn_v" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
# Attention output # Attention output
mapped_to = "blk."+str(i)+".attn_output" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
# Rotary embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
# Feed-forward norm
mapped_to = "blk."+str(i)+".ffn_norm"
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
# Feed-forward up # Feed-forward up
mapped_to = "blk."+str(i)+".ffn_up" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
# Feed-forward gate # Feed-forward gate
mapped_to = "blk."+str(i)+".ffn_gate" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
# Feed-forward down # Feed-forward down
mapped_to = "blk."+str(i)+".ffn_down" mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt