Use only one list of weight names, with values from the gguf module.

This saves weights in the order in which they are in the Grok-1 files.
Since we operate weight-by-weight now, we no longer need caches and
name2key translations.

Per reviewer request, I also moved to using keys in gguf.TENSOR_NAMES.
This commit is contained in:
Heiner 2024-05-03 23:44:47 +02:00
parent 3c57743874
commit 08427630c3

View file

@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
return np.float32, gguf.GGMLQuantizationType.F32
def dump_state_dict(f, weight_names, model_files, ggml_type, config):
keys2names = {}
meta_tensors = {}
weight_scales = {}
def dump_state_dict(f, ggml_type, input_dir, config):
weight_names = get_weight_names(config.num_hidden_layers)
weights = {}
# First operate on meta tensors to find shapes and dtypes for GGUF header.
for name, fn in model_files:
weight, scales = get_weights(fn)
meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
weight_scales[name] = (weight, scales)
for key in meta_state_dict.keys():
keys2names[key] = name
meta_tensors.update(meta_state_dict)
for key in weight_names:
meta_tensor = meta_tensors[key]
for idx, name in enumerate(weight_names):
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
f.add_tensor_info(
key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
f"{name}.weight",
list(meta_tensor.shape),
dtype,
quantized_meta_tensor.nbytes,
tensor_ggml_type,
)
print("Loaded", len(meta_tensors), "files")
weights[name] = weight, scales
print("Loaded", len(weight_names), "files")
f.write_header_to_file()
f.write_kv_data_to_file()
f.write_ti_data_to_file()
cache = {}
# Now write actual tensor data.
tensor_info = []
for key in weight_names:
if key not in cache:
name = keys2names[key]
weight, scales = weight_scales.pop(name)
state_dict = convert_weight(name, weight, scales, config.experts)
permute_tensors(state_dict, config)
cache.update(state_dict)
tensor = cache.pop(key)
for name in weight_names:
weight, scales = weights.pop(name)
tensor = convert_weight(name, weight, scales, config.experts)
tensor = maybe_permute_tensor(name, tensor, config)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
print(
f"dumping {key}:",
f"dumping {name}:",
f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
)
f.write_tensor_data(array)
tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))
try:
print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
@ -263,10 +255,8 @@ def from_numpy(array):
return torch.from_numpy(array)
def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
result = {}
weight = from_numpy(weight).to(device=device, dtype=dtype)
if scales is not None:
scale = from_numpy(scales).to(device=device, dtype=dtype)
@ -279,18 +269,16 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
weight = weight * scale
# Transpose linear matrix
if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
if len(weight.shape) >= 2 and "token_embd" not in name:
weight = weight.transpose(-1, -2)
if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
result[tensor_name] = weight[experts] # gather.
elif "experts" not in tensor_name:
result[tensor_name] = weight
if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
weight = weight[experts] # gather.
return result
return weight
def permute_tensors(state_dict, config):
def maybe_permute_tensor(name, tensor, config):
def permute(weights, n_head):
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
@ -298,11 +286,12 @@ def permute_tensors(state_dict, config):
.reshape(weights.shape)
)
for name, tensor in state_dict.items():
if name.endswith("attn_k.weight"):
state_dict[name] = permute(tensor, config.num_key_value_heads)
elif name.endswith("attn_q.weight"):
state_dict[name] = permute(tensor, config.num_attention_heads)
if name.endswith("attn_k"):
return permute(tensor, config.num_key_value_heads)
elif name.endswith("attn_q"):
return permute(tensor, config.num_attention_heads)
return tensor
def extract_vocabulary_from_model(vocab):
@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab):
return tokens, scores, toktypes
def get_weight_names(config):
weight_names = ["token_embd.weight"]
for i in range(config.num_hidden_layers):
weight_names += [
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
def get_weight_names(num_hidden_layers=64):
"""Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
weight_names += ["output_norm.weight"]
weight_names = [
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
]
layer = (
gguf.MODEL_TENSOR.FFN_GATE_EXP,
gguf.MODEL_TENSOR.FFN_DOWN_EXP,
gguf.MODEL_TENSOR.FFN_UP_EXP,
gguf.MODEL_TENSOR.ATTN_K,
gguf.MODEL_TENSOR.ATTN_OUT,
gguf.MODEL_TENSOR.ATTN_Q,
gguf.MODEL_TENSOR.ATTN_V,
gguf.MODEL_TENSOR.ATTN_NORM,
gguf.MODEL_TENSOR.ATTN_OUT_NORM,
gguf.MODEL_TENSOR.FFN_NORM,
gguf.MODEL_TENSOR.LAYER_OUT_NORM,
gguf.MODEL_TENSOR.FFN_GATE_INP,
)
for bid in range(num_hidden_layers):
for key in layer:
weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
return weight_names
@ -383,28 +379,6 @@ def convert_grok(args, vocab, ggml_type):
assert config.num_experts >= 2, "need at least 2 experts"
print("experts to export:", config.experts)
# Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
tensor_names = [
"token_embd.weight",
"output_norm.weight",
]
for i in range(config.num_hidden_layers):
tensor_names += [
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
f.add_name("grok")
@ -430,8 +404,7 @@ def convert_grok(args, vocab, ggml_type):
f.add_token_scores(scores)
f.add_token_types(toktypes)
weight_names = get_weight_names(config)
dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
dump_state_dict(f, ggml_type, args.input_dir, config)
f.close()
delta = time.time() - start
@ -465,7 +438,7 @@ def load_vocab(path):
def main():
parser = argparse.ArgumentParser("convert_grok")
parser.add_argument("-i", "--input", type=str)
parser.add_argument("-i", "--input_dir", type=str)
parser.add_argument("-o", "--save_path", type=pathlib.Path)
parser.add_argument(
"-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
@ -474,7 +447,9 @@ def main():
parser.add_argument("--experts", type=str, default="")
args = parser.parse_args()
vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
vocab = load_vocab(
pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
)
ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
convert_grok(args, vocab, ggml_type)