gguf: Refactor tensor name mapping
parent c7b0952eb7
commit 531746e953
8 changed files with 187 additions and 187 deletions
@@ -238,11 +238,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()

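Editor's note: every converter script touched by this commit replaces the same open-coded if/elif/else over the ".weight"/".bias" suffixes with a single `tensor_map.get_name(...)` call. As a reading aid only (not part of the diff), the lookup behaves roughly like the sketch below; the real implementation is the `TensorNameMap.get_both`/`get_name` pair added to gguf.py in the last hunk of this commit.

```python
# Editor's sketch of what tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
# does, written against a plain dict; the actual class lives in gguf.py below.
from typing import Dict, Optional, Sequence

def get_name_sketch(mapping: Dict[str, str], name: str, try_suffixes: Sequence[str]) -> Optional[str]:
    if name in mapping:                      # exact match first
        return mapping[name]
    for suffix in try_suffixes:              # otherwise strip a known suffix and retry
        if name.endswith(suffix) and name[:-len(suffix)] in mapping:
            return mapping[name[:-len(suffix)]] + suffix
    return None                              # caller prints an error and exits
```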
@@ -261,9 +258,9 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
@@ -111,7 +111,7 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_eps"])

 print("gguf: get tokenizer metadata")

-tokens: List[str] = []
+tokens: List[bytearray] = []
 merges: List[str] = []

@@ -200,7 +200,7 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 print("gguf: get tensor metadata")

 if num_parts == 0:
-    part_names = ("pytorch_model.bin",)
+    part_names = iter(("pytorch_model.bin",))
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
@@ -226,11 +226,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()

@@ -249,9 +246,9 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
@@ -11,7 +11,7 @@ import json
 import numpy as np
 import torch

-from typing import Any, List
+from typing import Any, List, TypeAlias
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor

@@ -266,11 +266,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()

@@ -289,9 +286,9 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
@@ -259,20 +259,13 @@ class GGMLToGGUF:
         gguf_writer.add_eos_token_id(2)

     def add_tensors(self, gguf_writer):
-        nm = self.name_map
+        tensor_map = self.name_map
         data = self.data
         print(f'* Adding {len(self.model.tensors)} tensor(s)')
         for tensor in self.model.tensors:
             name = str(tensor.name, 'UTF-8')
-            if name.endswith('.weight'):
-                name = name[:-7]
-                suffix = '.weight'
-            elif name.endswith('.bias'):
-                name = name[:-5]
-                suffix = '.bias'
-            mapped_name = nm.get(name)
+            mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
             assert mapped_name is not None, f'Bad name {name}'
-            mapped_name += suffix
             tempdims = list(tensor.dims[:])
             if len(tempdims) > 1:
                 temp = tempdims[1]
@@ -286,11 +286,8 @@ for part_name in part_names:
             data = reverse_hf_permute(data, head_count, head_count_kv)

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()

@@ -309,9 +306,9 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
@@ -4,7 +4,7 @@ import os
 import re
 import struct
 import sys
-from typing import Any, Dict, Sequence, TextIO
+from typing import Any, Dict, Sequence, BinaryIO

 import numpy as np
 import torch
@@ -46,7 +46,7 @@ def translate_tensor_name(t: str) -> str:
         sys.exit(1)


-def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
+def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
     fout.write(b"ggla"[::-1]) # magic (ggml lora)
     fout.write(struct.pack("i", 1)) # file version
     fout.write(struct.pack("i", params["r"]))
convert.py (22 changed lines)
@@ -1013,7 +1013,8 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
             for (name, tensor) in model.items()}

 def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
-    tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
+    tmap = gguf.TensorNameMap(ARCH, params.n_layer)
+    should_skip: Set[gguf.MODEL_TENSOR] = gguf.MODEL_TENSOR_SKIP.get(ARCH, set())

     tmp = model

@@ -1035,23 +1036,16 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:

     out: LazyModel = {}
     for name, lazy_tensor in model.items():
-        name_new = name
-        if name in tmap:
-            name_new = tmap[name]
-        elif name.endswith(".weight") and name[:-7] in tmap:
-            name_new = tmap[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tmap:
-            name_new = tmap[name[:-5]] + ".bias"
-        else:
+        tensor_type, name_new = tmap.get_both(name, try_suffixes = (".weight", ".bias")) or (None, None)
+        if name_new is None:
             raise Exception(f"Unexpected tensor name: {name}")

-        if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new):
+        if tensor_type in should_skip:
             print(f"skipping tensor {name_new}")
             continue
-        else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
-            out[name_new] = lazy_tensor
+        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
+        out[name_new] = lazy_tensor

     return out

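Editor's note: convert.py now asks the map for both the tensor type and the mapped name, so the skip check compares enum members against `MODEL_TENSOR_SKIP` instead of formatting every per-block name and calling the removed `should_skip_tensor_TMP` helper. A minimal sketch of that call pattern, in the terms used by the hunk above (`ARCH`, `tmap`, `gguf`); the tensor name is only an illustrative llama-hf style name:

```python
# Editor's sketch of the new convert.py lookup; ARCH, tmap and gguf are as used above.
should_skip = gguf.MODEL_TENSOR_SKIP.get(ARCH, set())

name = "model.layers.0.self_attn.rotary_emb.inv_freq"   # illustrative only
tensor_type, name_new = tmap.get_both(name, try_suffixes = (".weight", ".bias")) or (None, None)

if tensor_type in should_skip:
    print(f"skipping tensor {name_new}")
```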
@@ -6,7 +6,7 @@ import tempfile
 import numpy as np

 from enum import IntEnum, auto
-from typing import Any, BinaryIO, IO, Dict, List, Optional, Tuple
+from typing import Any, BinaryIO, IO, Dict, List, Optional, Sequence, Tuple

 #
 # constants
@@ -162,167 +162,192 @@ MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
 }


-# TODO: the following helper functions should be removed
-# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR)
-# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions
-# REMOVE
-def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool:
-    for skip in MODEL_TENSOR_SKIP.get(arch, []):
-        for i in range(n_blocks):
-            if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
-                return True
-
-    return False
-
-
-def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict:
-    tensor_map = {}
-
-    # Token embeddings
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None)
-
-    tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
-    tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
-    tensor_map["transformer.word_embeddings"] = mapped_to # falcon
-    tensor_map["model.embed_tokens"] = mapped_to # llama-hf
-    tensor_map["tok_embeddings"] = mapped_to # llama-pth
-
-    # Position embeddings
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
-
-    tensor_map["transformer.wpe"] = mapped_to # gpt2
-
-    # Output
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
-
-    tensor_map["embed_out"] = mapped_to # gptneox
-    tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
-    tensor_map["output"] = mapped_to # llama-pth
-
-    # Output norm
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
-
-    tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
-    tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
-    tensor_map["transformer.norm_f"] = mapped_to # mpt
-    tensor_map["model.norm"] = mapped_to # llama-hf
-    tensor_map["norm"] = mapped_to # llama-pth
-
-    # Rope frequencies
-    mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
-
-    tensor_map["rope.freqs"] = mapped_to # llama-pth
-
-    # Attention and feed-forward blocks
-    for i in range(0, n_blocks):
-        # Attention norm
-        # TODO: is there are simpler way to write these 2 lines in Python?
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b
-        tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
-        tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
-
-        # Attention norm 2
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
-
-        # Attention query-key-value
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
-
-        # Attention query
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
-
-        # Attention key
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
-
-        # Attention value
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
-
-        # Attention output
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
-
-        # Rotary embeddings
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
-
-        # Feed-forward norm
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
-        tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
-
-        # Feed-forward up
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
-
-        # Feed-forward gate
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
-
-        # Feed-forward down
-        mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None)
-        mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
-
-        tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
-        tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
-        tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
-        tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon
-        tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf
-        tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth
-
-    return tensor_map
+class TensorNameMap:
+    mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
+        # Token embeddings
+        MODEL_TENSOR.TOKEN_EMBD: (
+            "gpt_neox.embed_in", # gptneox
+            "transformer.wte", # gpt2 mpt
+            "transformer.word_embeddings", # falcon
+            "model.embed_tokens", # llama-hf
+            "tok_embeddings", # llama-pth
+        ),
+
+        # Position embeddings
+        MODEL_TENSOR.POS_EMBD: (
+            "transformer.wpe", # gpt2
+        ),
+
+        # Output
+        MODEL_TENSOR.OUTPUT: (
+            "embed_out", # gptneox
+            "lm_head", # gpt2 mpt falcon llama-hf
+            "output", # llama-pth
+        ),
+
+        # Output norm
+        MODEL_TENSOR.OUTPUT_NORM: (
+            "gpt_neox.final_layer_norm", # gptneox
+            "transformer.ln_f", # gpt2 falcon
+            "model.norm", # llama-hf
+            "norm", # llama-pth
+        ),
+
+        # Rope frequencies
+        MODEL_TENSOR.ROPE_FREQS: (
+            "rope.freqs", # llama-pth
+        ),
+    }
+
+    block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
+        # Attention norm
+        MODEL_TENSOR.ATTN_NORM: (
+            "gpt_neox.layers.{bid}.input_layernorm", # gptneox
+            "transformer.h.{bid}.ln_1", # gpt2
+            "transformer.blocks.{bid}.norm_1", # mpt
+            "transformer.h.{bid}.input_layernorm", # falcon7b
+            "transformer.h.{bid}.ln_mlp", # falcon40b
+            "model.layers.{bid}.input_layernorm", # llama-hf
+            "layers.{bid}.attention_norm", # llama-pth
+        ),
+
+        # Attention norm 2
+        MODEL_TENSOR.ATTN_NORM_2: (
+            "transformer.h.{bid}.ln_attn", # falcon40b
+        ),
+
+        # Attention query-key-value
+        MODEL_TENSOR.ATTN_QKV: (
+            "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
+            "transformer.h.{bid}.attn.c_attn", # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv", # mpt
+            "transformer.h.{bid}.self_attention.query_key_value", # falcon
+        ),
+
+        # Attention query
+        MODEL_TENSOR.ATTN_Q: (
+            "model.layers.{bid}.self_attn.q_proj", # llama-hf
+            "layers.{bid}.attention.wq", # llama-pth
+        ),
+
+        # Attention key
+        MODEL_TENSOR.ATTN_K: (
+            "model.layers.{bid}.self_attn.k_proj", # llama-hf
+            "layers.{bid}.attention.wk", # llama-pth
+        ),
+
+        # Attention value
+        MODEL_TENSOR.ATTN_V: (
+            "model.layers.{bid}.self_attn.v_proj", # llama-hf
+            "layers.{bid}.attention.wv", # llama-pth
+        ),
+
+        # Attention output
+        MODEL_TENSOR.ATTN_OUT: (
+            "gpt_neox.layers.{bid}.attention.dense", # gptneox
+            "transformer.h.{bid}.attn.c_proj", # gpt2
+            "transformer.blocks.{bid}.attn.out_proj", # mpt
+            "transformer.h.{bid}.self_attention.dense", # falcon
+            "model.layers.{bid}.self_attn.o_proj", # llama-hf
+            "layers.{bid}.attention.wo", # llama-pth
+        ),
+
+        # Rotary embeddings
+        MODEL_TENSOR.ATTN_ROT_EMBD: (
+            "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
+            "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
+        ),
+
+        # Feed-forward norm
+        MODEL_TENSOR.FFN_NORM: (
+            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
+            "transformer.h.{bid}.ln_2", # gpt2
+            "transformer.blocks.{bid}.norm_2", # mpt
+            "model.layers.{bid}.post_attention_layernorm", # llama-hf
+            "layers.{bid}.ffn_norm", # llama-pth
+        ),
+
+        # Feed-forward up
+        MODEL_TENSOR.FFN_UP: (
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
+            "transformer.h.{bid}.mlp.c_fc", # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj", # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
+            "model.layers.{bid}.mlp.up_proj", # llama-hf
+            "layers.{bid}.feed_forward.w3", # llama-pth
+        ),
+
+        # Feed-forward gate
+        MODEL_TENSOR.FFN_GATE: (
+            "model.layers.{bid}.mlp.gate_proj", # llama-hf
+            "layers.{bid}.feed_forward.w1", # llama-pth
+        ),
+
+        # Feed-forward down
+        MODEL_TENSOR.FFN_DOWN: (
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+            "transformer.h.{bid}.mlp.c_proj", # gpt2
+            "transformer.blocks.{bid}.ffn.down_proj", # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
+            "model.layers.{bid}.mlp.down_proj", # llama-hf
+            "layers.{bid}.feed_forward.w2", # llama-pth
+        ),
+    }
+
+    mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
+
+    tensor_names: Dict[MODEL_TENSOR, str]
+
+    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
+        mapping = self.mapping = {}
+        tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
+        for tensor, keys in self.mappings_cfg.items():
+            tensor_name = tensor_names.get(tensor)
+            if tensor_name is None:
+                continue
+            for key in keys:
+                mapping[key] = (tensor, tensor_name)
+        for bid in range(n_blocks):
+            for tensor, keys in self.block_mappings_cfg.items():
+                tensor_name = tensor_names.get(tensor)
+                if tensor_name is None:
+                    continue
+                tensor_name = tensor_name.format(bid = bid)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    mapping[key] = (tensor, tensor_name)
+
+    def get_both(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
+        result = self.mapping.get(key)
+        if result is not None:
+            return result
+        for suffix in try_suffixes:
+            if key.endswith(suffix):
+                result = self.mapping.get(key[:-len(suffix)])
+                if result is not None:
+                    return (result[0], result[1] + suffix)
+        return None
+
+    def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
+        result = self.get_both(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[1]
+
+    def __getitem__(self, key: str) -> str:
+        try:
+            return self.mapping[key][1]
+        except KeyError:
+            raise KeyError(key)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.mapping
+
+    def __repr__(self) -> str:
+        return repr(self.mapping)
+
+
+def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
+    return TensorNameMap(arch, n_blocks)


 class TokenType(IntEnum):
     NORMAL = 1
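Editor's note: to summarize the new gguf.py API in one place, a `TensorNameMap` is built per architecture and block count, and lookups can be exact or suffix-aware. A short usage sketch follows; it assumes the gguf package from this repository is importable and that `MODEL_ARCH.LLAMA` is defined as in gguf.py, and the printed values depend on `MODEL_TENSOR_NAMES` for that architecture (they are only indicative).

```python
# Editor's usage sketch for the new TensorNameMap (assumptions noted above).
import gguf

tmap = gguf.TensorNameMap(gguf.MODEL_ARCH.LLAMA, 32)

# Exact lookup via __getitem__ (raises KeyError for unknown names):
print(tmap["model.layers.0.self_attn.q_proj"])

# Suffix-aware lookup: ".weight"/".bias" is stripped for matching, then re-appended:
print(tmap.get_name("model.layers.0.self_attn.q_proj.weight", try_suffixes = (".weight", ".bias")))

# get_both additionally returns the MODEL_TENSOR enum member, which convert.py
# compares against MODEL_TENSOR_SKIP to decide whether a tensor is dropped:
print(tmap.get_both("layers.0.attention.wq.weight", try_suffixes = (".weight", ".bias")))
```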