Use only one list of weight names, with values from the gguf module.
This saves weights in the order in which they appear in the Grok-1 files. Since we now operate weight-by-weight, we no longer need the caches or the name2key translation. Per reviewer request, I also moved to using the keys in gguf.TENSOR_NAMES.
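
For context, a minimal sketch of the lookup this commit relies on, assuming the gguf-py package that ships with llama.cpp (where gguf.TENSOR_NAMES maps gguf.MODEL_TENSOR members to name templates, and per-layer templates carry a {bid} block-index placeholder):

    import gguf

    # Global tensors map to fixed names (no ".weight" suffix;
    # the converter appends that itself in add_tensor_info).
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD])            # "token_embd"

    # Per-layer tensors are templates, filled in with the block index.
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_K].format(bid=3))  # "blk.3.attn_k"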
parent 3c57743874
commit 08427630c3
1 changed file with 62 additions and 87 deletions
convert_grok.py (+62, −87)

@@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
     return np.float32, gguf.GGMLQuantizationType.F32
 
 
-def dump_state_dict(f, weight_names, model_files, ggml_type, config):
-    keys2names = {}
-    meta_tensors = {}
-    weight_scales = {}
+def dump_state_dict(f, ggml_type, input_dir, config):
+    weight_names = get_weight_names(config.num_hidden_layers)
+    weights = {}
 
     # First operate on meta tensors to find shapes and dtypes for GGUF header.
-    for name, fn in model_files:
-        weight, scales = get_weights(fn)
-        meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
-        weight_scales[name] = (weight, scales)
-        for key in meta_state_dict.keys():
-            keys2names[key] = name
-
-        meta_tensors.update(meta_state_dict)
-
-    for key in weight_names:
-        meta_tensor = meta_tensors[key]
+    for idx, name in enumerate(weight_names):
+        weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
+        meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
         dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
         quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
         f.add_tensor_info(
-            key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
+            f"{name}.weight",
+            list(meta_tensor.shape),
+            dtype,
+            quantized_meta_tensor.nbytes,
+            tensor_ggml_type,
         )
-    print("Loaded", len(meta_tensors), "files")
+        weights[name] = weight, scales
+    print("Loaded", len(weight_names), "files")
 
     f.write_header_to_file()
     f.write_kv_data_to_file()
     f.write_ti_data_to_file()
 
-    cache = {}
+    # Now write actual tensor data.
     tensor_info = []
 
-    for key in weight_names:
-        if key not in cache:
-            name = keys2names[key]
-            weight, scales = weight_scales.pop(name)
-            state_dict = convert_weight(name, weight, scales, config.experts)
-            permute_tensors(state_dict, config)
-            cache.update(state_dict)
-        tensor = cache.pop(key)
+    for name in weight_names:
+        weight, scales = weights.pop(name)
+        tensor = convert_weight(name, weight, scales, config.experts)
+        tensor = maybe_permute_tensor(name, tensor, config)
         _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
         array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
 
         print(
-            f"dumping {key}:",
+            f"dumping {name}:",
             f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
         )
         f.write_tensor_data(array)
 
-        tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
+        tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))
 
     try:
         print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
@@ -263,10 +255,8 @@ def from_numpy(array):
     return torch.from_numpy(array)
 
 
-def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
+def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
     # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
-    result = {}
-
     weight = from_numpy(weight).to(device=device, dtype=dtype)
     if scales is not None:
         scale = from_numpy(scales).to(device=device, dtype=dtype)
@@ -279,18 +269,16 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
         weight = weight * scale
 
     # Transpose linear matrix
-    if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
+    if len(weight.shape) >= 2 and "token_embd" not in name:
         weight = weight.transpose(-1, -2)
 
-    if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
-        result[tensor_name] = weight[experts]  # gather.
-    elif "experts" not in tensor_name:
-        result[tensor_name] = weight
-
-    return result
+    if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
+        weight = weight[experts]  # gather.
+
+    return weight
 
 
-def permute_tensors(state_dict, config):
+def maybe_permute_tensor(name, tensor, config):
     def permute(weights, n_head):
         return (
             weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
@@ -298,11 +286,12 @@ def permute_tensors(state_dict, config):
             .reshape(weights.shape)
         )
 
-    for name, tensor in state_dict.items():
-        if name.endswith("attn_k.weight"):
-            state_dict[name] = permute(tensor, config.num_key_value_heads)
-        elif name.endswith("attn_q.weight"):
-            state_dict[name] = permute(tensor, config.num_attention_heads)
+    if name.endswith("attn_k"):
+        return permute(tensor, config.num_key_value_heads)
+    elif name.endswith("attn_q"):
+        return permute(tensor, config.num_attention_heads)
+
+    return tensor
 
 
 def extract_vocabulary_from_model(vocab):
@@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab):
     return tokens, scores, toktypes
 
 
-def get_weight_names(config):
-    weight_names = ["token_embd.weight"]
-    for i in range(config.num_hidden_layers):
-        weight_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    weight_names += ["output_norm.weight"]
+def get_weight_names(num_hidden_layers=64):
+    """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
+
+    weight_names = [
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
+    ]
+
+    layer = (
+        gguf.MODEL_TENSOR.FFN_GATE_EXP,
+        gguf.MODEL_TENSOR.FFN_DOWN_EXP,
+        gguf.MODEL_TENSOR.FFN_UP_EXP,
+        gguf.MODEL_TENSOR.ATTN_K,
+        gguf.MODEL_TENSOR.ATTN_OUT,
+        gguf.MODEL_TENSOR.ATTN_Q,
+        gguf.MODEL_TENSOR.ATTN_V,
+        gguf.MODEL_TENSOR.ATTN_NORM,
+        gguf.MODEL_TENSOR.ATTN_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_NORM,
+        gguf.MODEL_TENSOR.LAYER_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_GATE_INP,
+    )
+
+    for bid in range(num_hidden_layers):
+        for key in layer:
+            weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
 
     return weight_names
 
@@ -383,28 +379,6 @@ def convert_grok(args, vocab, ggml_type):
     assert config.num_experts >= 2, "need at least 2 experts"
     print("experts to export:", config.experts)
 
-    # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
-    tensor_names = [
-        "token_embd.weight",
-        "output_norm.weight",
-    ]
-    for i in range(config.num_hidden_layers):
-        tensor_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
 
     f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
 
     f.add_name("grok")
@@ -430,8 +404,7 @@ def convert_grok(args, vocab, ggml_type):
     f.add_token_scores(scores)
    f.add_token_types(toktypes)
 
-    weight_names = get_weight_names(config)
-    dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
+    dump_state_dict(f, ggml_type, args.input_dir, config)
     f.close()
 
     delta = time.time() - start
@@ -465,7 +438,7 @@ def load_vocab(path):
 
 def main():
     parser = argparse.ArgumentParser("convert_grok")
-    parser.add_argument("-i", "--input", type=str)
+    parser.add_argument("-i", "--input_dir", type=str)
     parser.add_argument("-o", "--save_path", type=pathlib.Path)
     parser.add_argument(
         "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
@@ -474,7 +447,9 @@ def main():
     parser.add_argument("--experts", type=str, default="")
     args = parser.parse_args()
 
-    vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
+    vocab = load_vocab(
+        pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
+    )
     ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
     convert_grok(args, vocab, ggml_type)
 
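
With the renamed flag, a hypothetical invocation of the updated converter (paths are placeholders, not from this commit) would look like:

    python convert_grok.py --input_dir ./grok-1 --save_path ./grok-1-q8_0.gguf --type q8_0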