Use only one list of weight names, with values from the gguf module.

This saves weights in the order in which they are in the Grok-1 files.
Since we operate weight-by-weight now, we no longer need caches and
name2key translations.

Per reviewer request, I also moved to using keys in gguf.TENSOR_NAMES.
This commit is contained in:
Heiner 2024-05-03 23:44:47 +02:00
parent 3c57743874
commit 08427630c3

View file

@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
return np.float32, gguf.GGMLQuantizationType.F32 return np.float32, gguf.GGMLQuantizationType.F32
def dump_state_dict(f, weight_names, model_files, ggml_type, config): def dump_state_dict(f, ggml_type, input_dir, config):
keys2names = {} weight_names = get_weight_names(config.num_hidden_layers)
meta_tensors = {} weights = {}
weight_scales = {}
# First operate on meta tensors to find shapes and dtypes for GGUF header. # First operate on meta tensors to find shapes and dtypes for GGUF header.
for name, fn in model_files: for idx, name in enumerate(weight_names):
weight, scales = get_weights(fn) weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta") meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
weight_scales[name] = (weight, scales)
for key in meta_state_dict.keys():
keys2names[key] = name
meta_tensors.update(meta_state_dict)
for key in weight_names:
meta_tensor = meta_tensors[key]
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
f.add_tensor_info( f.add_tensor_info(
key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type f"{name}.weight",
list(meta_tensor.shape),
dtype,
quantized_meta_tensor.nbytes,
tensor_ggml_type,
) )
print("Loaded", len(meta_tensors), "files") weights[name] = weight, scales
print("Loaded", len(weight_names), "files")
f.write_header_to_file() f.write_header_to_file()
f.write_kv_data_to_file() f.write_kv_data_to_file()
f.write_ti_data_to_file() f.write_ti_data_to_file()
cache = {} # Now write actual tensor data.
tensor_info = [] tensor_info = []
for key in weight_names: for name in weight_names:
if key not in cache: weight, scales = weights.pop(name)
name = keys2names[key] tensor = convert_weight(name, weight, scales, config.experts)
weight, scales = weight_scales.pop(name) tensor = maybe_permute_tensor(name, tensor, config)
state_dict = convert_weight(name, weight, scales, config.experts)
permute_tensors(state_dict, config)
cache.update(state_dict)
tensor = cache.pop(key)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
print( print(
f"dumping {key}:", f"dumping {name}:",
f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes", f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
) )
f.write_tensor_data(array) f.write_tensor_data(array)
tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name)) tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))
try: try:
print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
@ -263,10 +255,8 @@ def from_numpy(array):
return torch.from_numpy(array) return torch.from_numpy(array)
def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None): def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
result = {}
weight = from_numpy(weight).to(device=device, dtype=dtype) weight = from_numpy(weight).to(device=device, dtype=dtype)
if scales is not None: if scales is not None:
scale = from_numpy(scales).to(device=device, dtype=dtype) scale = from_numpy(scales).to(device=device, dtype=dtype)
@ -279,18 +269,16 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
weight = weight * scale weight = weight * scale
# Transpose linear matrix # Transpose linear matrix
if len(weight.shape) >= 2 and "token_embd" not in tensor_name: if len(weight.shape) >= 2 and "token_embd" not in name:
weight = weight.transpose(-1, -2) weight = weight.transpose(-1, -2)
if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"): if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
result[tensor_name] = weight[experts] # gather. weight = weight[experts] # gather.
elif "experts" not in tensor_name:
result[tensor_name] = weight
return result return weight
def permute_tensors(state_dict, config): def maybe_permute_tensor(name, tensor, config):
def permute(weights, n_head): def permute(weights, n_head):
return ( return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
@ -298,11 +286,12 @@ def permute_tensors(state_dict, config):
.reshape(weights.shape) .reshape(weights.shape)
) )
for name, tensor in state_dict.items(): if name.endswith("attn_k"):
if name.endswith("attn_k.weight"): return permute(tensor, config.num_key_value_heads)
state_dict[name] = permute(tensor, config.num_key_value_heads) elif name.endswith("attn_q"):
elif name.endswith("attn_q.weight"): return permute(tensor, config.num_attention_heads)
state_dict[name] = permute(tensor, config.num_attention_heads)
return tensor
def extract_vocabulary_from_model(vocab): def extract_vocabulary_from_model(vocab):
@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab):
return tokens, scores, toktypes return tokens, scores, toktypes
def get_weight_names(config): def get_weight_names(num_hidden_layers=64):
weight_names = ["token_embd.weight"] """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
for i in range(config.num_hidden_layers):
weight_names += [
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
weight_names += ["output_norm.weight"] weight_names = [
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
]
layer = (
gguf.MODEL_TENSOR.FFN_GATE_EXP,
gguf.MODEL_TENSOR.FFN_DOWN_EXP,
gguf.MODEL_TENSOR.FFN_UP_EXP,
gguf.MODEL_TENSOR.ATTN_K,
gguf.MODEL_TENSOR.ATTN_OUT,
gguf.MODEL_TENSOR.ATTN_Q,
gguf.MODEL_TENSOR.ATTN_V,
gguf.MODEL_TENSOR.ATTN_NORM,
gguf.MODEL_TENSOR.ATTN_OUT_NORM,
gguf.MODEL_TENSOR.FFN_NORM,
gguf.MODEL_TENSOR.LAYER_OUT_NORM,
gguf.MODEL_TENSOR.FFN_GATE_INP,
)
for bid in range(num_hidden_layers):
for key in layer:
weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
return weight_names return weight_names
@ -383,28 +379,6 @@ def convert_grok(args, vocab, ggml_type):
assert config.num_experts >= 2, "need at least 2 experts" assert config.num_experts >= 2, "need at least 2 experts"
print("experts to export:", config.experts) print("experts to export:", config.experts)
# Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
tensor_names = [
"token_embd.weight",
"output_norm.weight",
]
for i in range(config.num_hidden_layers):
tensor_names += [
f"blk.{i}.ffn_gate_exps.weight",
f"blk.{i}.ffn_down_exps.weight",
f"blk.{i}.ffn_up_exps.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
f.add_name("grok") f.add_name("grok")
@ -430,8 +404,7 @@ def convert_grok(args, vocab, ggml_type):
f.add_token_scores(scores) f.add_token_scores(scores)
f.add_token_types(toktypes) f.add_token_types(toktypes)
weight_names = get_weight_names(config) dump_state_dict(f, ggml_type, args.input_dir, config)
dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
f.close() f.close()
delta = time.time() - start delta = time.time() - start
@ -465,7 +438,7 @@ def load_vocab(path):
def main(): def main():
parser = argparse.ArgumentParser("convert_grok") parser = argparse.ArgumentParser("convert_grok")
parser.add_argument("-i", "--input", type=str) parser.add_argument("-i", "--input_dir", type=str)
parser.add_argument("-o", "--save_path", type=pathlib.Path) parser.add_argument("-o", "--save_path", type=pathlib.Path)
parser.add_argument( parser.add_argument(
"-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"] "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
@ -474,7 +447,9 @@ def main():
parser.add_argument("--experts", type=str, default="") parser.add_argument("--experts", type=str, default="")
args = parser.parse_args() args = parser.parse_args()
vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input)) vocab = load_vocab(
pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
)
ggml_type = gguf.GGMLQuantizationType[args.type.upper()] ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
convert_grok(args, vocab, ggml_type) convert_grok(args, vocab, ggml_type)