diff --git a/convert_grok.py b/convert_grok.py
index d196f7a7b..f3aa6d369 100644
--- a/convert_grok.py
+++ b/convert_grok.py
@@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
     return np.float32, gguf.GGMLQuantizationType.F32
 
 
-def dump_state_dict(f, weight_names, model_files, ggml_type, config):
-    keys2names = {}
-    meta_tensors = {}
-    weight_scales = {}
+def dump_state_dict(f, ggml_type, input_dir, config):
+    weight_names = get_weight_names(config.num_hidden_layers)
+    weights = {}
 
     # First operate on meta tensors to find shapes and dtypes for GGUF header.
-    for name, fn in model_files:
-        weight, scales = get_weights(fn)
-        meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
-        weight_scales[name] = (weight, scales)
-        for key in meta_state_dict.keys():
-            keys2names[key] = name
-
-        meta_tensors.update(meta_state_dict)
-
-    for key in weight_names:
-        meta_tensor = meta_tensors[key]
+    for idx, name in enumerate(weight_names):
+        weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
+        meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
         dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
         quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
         f.add_tensor_info(
-            key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
+            f"{name}.weight",
+            list(meta_tensor.shape),
+            dtype,
+            quantized_meta_tensor.nbytes,
+            tensor_ggml_type,
         )
-    print("Loaded", len(meta_tensors), "files")
+        weights[name] = weight, scales
+    print("Loaded", len(weight_names), "files")
 
     f.write_header_to_file()
     f.write_kv_data_to_file()
     f.write_ti_data_to_file()
 
-    cache = {}
+    # Now write actual tensor data.
     tensor_info = []
 
-    for key in weight_names:
-        if key not in cache:
-            name = keys2names[key]
-            weight, scales = weight_scales.pop(name)
-            state_dict = convert_weight(name, weight, scales, config.experts)
-            permute_tensors(state_dict, config)
-            cache.update(state_dict)
-        tensor = cache.pop(key)
+    for name in weight_names:
+        weight, scales = weights.pop(name)
+        tensor = convert_weight(name, weight, scales, config.experts)
+        tensor = maybe_permute_tensor(name, tensor, config)
         _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
         array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
         print(
-            f"dumping {key}:",
+            f"dumping {name}:",
            f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
         )
         f.write_tensor_data(array)
-        tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
+        tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))
 
     try:
         print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
@@ -263,10 +255,8 @@ def from_numpy(array):
     return torch.from_numpy(array)
 
 
-def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
+def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
     # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
-    result = {}
-
     weight = from_numpy(weight).to(device=device, dtype=dtype)
     if scales is not None:
         scale = from_numpy(scales).to(device=device, dtype=dtype)
@@ -279,18 +269,16 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
             weight = weight * scale
 
     # Transpose linear matrix
-    if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
+    if len(weight.shape) >= 2 and "token_embd" not in name:
         weight = weight.transpose(-1, -2)
- if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"): - result[tensor_name] = weight[experts] # gather. - elif "experts" not in tensor_name: - result[tensor_name] = weight + if name.endswith("ffn_gate_inp") or name.endswith("_exps"): + weight = weight[experts] # gather. - return result + return weight -def permute_tensors(state_dict, config): +def maybe_permute_tensor(name, tensor, config): def permute(weights, n_head): return ( weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) @@ -298,11 +286,12 @@ def permute_tensors(state_dict, config): .reshape(weights.shape) ) - for name, tensor in state_dict.items(): - if name.endswith("attn_k.weight"): - state_dict[name] = permute(tensor, config.num_key_value_heads) - elif name.endswith("attn_q.weight"): - state_dict[name] = permute(tensor, config.num_attention_heads) + if name.endswith("attn_k"): + return permute(tensor, config.num_key_value_heads) + elif name.endswith("attn_q"): + return permute(tensor, config.num_attention_heads) + + return tensor def extract_vocabulary_from_model(vocab): @@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab): return tokens, scores, toktypes -def get_weight_names(config): - weight_names = ["token_embd.weight"] - for i in range(config.num_hidden_layers): - weight_names += [ - f"blk.{i}.ffn_gate_exps.weight", - f"blk.{i}.ffn_down_exps.weight", - f"blk.{i}.ffn_up_exps.weight", - f"blk.{i}.attn_k.weight", - f"blk.{i}.attn_output.weight", - f"blk.{i}.attn_q.weight", - f"blk.{i}.attn_v.weight", - f"blk.{i}.attn_norm.weight", - f"blk.{i}.attn_output_norm.weight", - f"blk.{i}.ffn_norm.weight", - f"blk.{i}.layer_output_norm.weight", - f"blk.{i}.ffn_gate_inp.weight", - ] +def get_weight_names(num_hidden_layers=64): + """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files.""" - weight_names += ["output_norm.weight"] + weight_names = [ + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD], + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM], + ] + + layer = ( + gguf.MODEL_TENSOR.FFN_GATE_EXP, + gguf.MODEL_TENSOR.FFN_DOWN_EXP, + gguf.MODEL_TENSOR.FFN_UP_EXP, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_NORM, + gguf.MODEL_TENSOR.ATTN_OUT_NORM, + gguf.MODEL_TENSOR.FFN_NORM, + gguf.MODEL_TENSOR.LAYER_OUT_NORM, + gguf.MODEL_TENSOR.FFN_GATE_INP, + ) + + for bid in range(num_hidden_layers): + for key in layer: + weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid)) return weight_names @@ -383,28 +379,6 @@ def convert_grok(args, vocab, ggml_type): assert config.num_experts >= 2, "need at least 2 experts" print("experts to export:", config.experts) - # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later. 
-    tensor_names = [
-        "token_embd.weight",
-        "output_norm.weight",
-    ]
-    for i in range(config.num_hidden_layers):
-        tensor_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
 
     f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
     f.add_name("grok")
@@ -430,8 +404,7 @@ def convert_grok(args, vocab, ggml_type):
     f.add_token_scores(scores)
     f.add_token_types(toktypes)
 
-    weight_names = get_weight_names(config)
-    dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
+    dump_state_dict(f, ggml_type, args.input_dir, config)
     f.close()
 
     delta = time.time() - start
@@ -465,7 +438,7 @@ def load_vocab(path):
 
 def main():
     parser = argparse.ArgumentParser("convert_grok")
-    parser.add_argument("-i", "--input", type=str)
+    parser.add_argument("-i", "--input_dir", type=str)
     parser.add_argument("-o", "--save_path", type=pathlib.Path)
     parser.add_argument(
         "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
@@ -474,7 +447,9 @@ def main():
     parser.add_argument("--experts", type=str, default="")
     args = parser.parse_args()
 
-    vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
+    vocab = load_vocab(
+        pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
+    )
     ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
     convert_grok(args, vocab, ggml_type)
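
Reviewer's note, not part of the patch: after this change, get_weight_names() both names the tensors and fixes the order in which dump_state_dict() reads the tensor#####_000 files, so a name's index in the list is also its file index. A minimal sketch of that pairing, assuming the stock gguf-py name templates ("token_embd", "output_norm", "blk.{bid}.ffn_gate_exps", ...) and that convert_grok.py is on the import path:

    # Illustrative only: show which on-disk file each GGUF tensor name is read from.
    from convert_grok import get_weight_names

    for idx, name in enumerate(get_weight_names(num_hidden_layers=64)[:3]):
        # dump_state_dict() loads f"{input_dir}/tensor{idx:05}_000" and registers
        # the converted tensor as f"{name}.weight" in the GGUF header.
        print(f"tensor{idx:05}_000 -> {name}.weight")
    # tensor00000_000 -> token_embd.weight
    # tensor00001_000 -> output_norm.weight
    # tensor00002_000 -> blk.0.ffn_gate_exps.weight

With the renamed flag, a typical invocation would be along the lines of `python convert_grok.py -i <grok-1 checkpoint dir> -o grok.gguf -t q8_0` (paths hypothetical); `-i` now populates args.input_dir, which both dump_state_dict() and the vocabulary fallback in main() consume.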