From 6ddf93b286ae680aef3677dd2255fb4320072233 Mon Sep 17 00:00:00 2001 From: Heiner Date: Wed, 24 Apr 2024 23:39:06 +0200 Subject: [PATCH 01/13] Script to convert Grok-1 weights from raw JAX pickle files. --- convert_grok.py | 492 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 convert_grok.py diff --git a/convert_grok.py b/convert_grok.py new file mode 100644 index 000000000..8096931f5 --- /dev/null +++ b/convert_grok.py @@ -0,0 +1,492 @@ +""" +Convert Grok-1 weights to GGUF format. + +Example invocation: + +python -m convert_grok -i path/to/grok-1/ckpt-0 --vocab_dir path/to/grok -o grok.bin -t q4_0 --experts 1,2 + +To run: + +./build/bin/main -m grok.bin -p "The answer to life the universe and everything is" -s 1 -n 3 -ngl 1 +""" + +import argparse +import mmap +import os +import pathlib +import pickletools +import sys +import time + +import ml_dtypes +import numpy as np +import torch + +try: + from tabulate import tabulate +except ModuleNotFoundError: + pass + +from convert import SentencePieceVocab + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(pathlib.Path(__file__).parent / "gguf-py")) + +import gguf + +GGML_QK8_0 = 32 +GGML_QK4_0 = 32 +GGML_QK4_1 = 32 + + +# Heuristic to avoid having to fully parse pickle files. +FP32_SHAPES = {805306368: (131072, 6144), 6144: (6144,), 49152: (6144, 8)} +BF16_SHAPES = { + 262144: (8, 1, 32768), + 393216: (8, 8, 6144), + 1024: (1, 1024), + 49152: (8, 6144), + 6144: (1, 6144), +} + + +class AttributeDict(dict): + def __getattr__(self, key): + return self.__getitem__(key) if key in self else super().__getattr__(key) + + __setattr__ = dict.__setitem__ + + +def _genops(data): + view = memoryview(data) + + code2op = {ord(d.code): d for d in pickletools.opcodes} + dataops = { + "BINBYTES": pickletools.read_uint4, + "BINBYTES8": pickletools.read_uint8, + } + + while True: + pos = data.tell() + code = data.read_byte() + opcode = code2op[code] + + arg = None + if opcode.arg is not None: + if opcode.name not in dataops: + arg = opcode.arg.reader(data) + else: + size = dataops[opcode.name](data) + p = data.tell() + arg = np.frombuffer(view[p : p + size], dtype=np.uint8) + data.seek(size, 1) + + yield opcode, arg, pos + if code == ord(b"."): + break + + +def genops(fn): + """Yield (opcode, arg, pos) from for a pickle file. + + Uses mmap to avoid copies of binary data (e.g., np and JAX arrays).""" + with open(fn, "rb") as f: + yield from _genops(mmap.mmap(f.fileno(), length=0, flags=mmap.MAP_PRIVATE)) + + +def get_weights(fn): + """Returns tensor/array data in Grok pickle files, zero copy.""" + + arrays = [] + for unused_opcode, arg, unused_pos in genops(fn): + if isinstance(arg, np.ndarray): + arrays.append(arg) + + if len(arrays) == 1: + # Plain numpy array. 
+ array = arrays[0].view(np.float32) + array = array.reshape(FP32_SHAPES[array.size]) + return array, None + elif len(arrays) == 2: + weight, scales = arrays + + scales = scales.view(ml_dtypes.bfloat16) + scales = scales.reshape(BF16_SHAPES[scales.size]) + + weight = weight.view(np.int8) + shape = list(scales.shape) + shape[-2] = -1 + weight = weight.reshape(shape) + return weight, scales + + assert len(arrays) in (1, 2) + + +def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: + # equivalent to ggml_quantize_q8_0 in ggml.c + assert tensor.shape[1] % GGML_QK8_0 == 0 + tensor = tensor.view(-1, GGML_QK8_0) + scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) + tensor = (tensor / scale).round().clamp(min=-128, max=127).char() + # add scale into each block + tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1) + return tensor + + +def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: + # equivalent to ggml_quantize_q4_0 in ggml.c + assert tensor.shape[1] % GGML_QK4_0 == 0 + tensor = tensor.reshape(-1, GGML_QK4_0) + abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices + max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1) + scale = max_values / -8 + tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char() + # compress two int4 weights into a int8 + tensor = tensor[:, :16] | (tensor[:, 16:] << 4) + # add scale into each block + tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1) + return tensor + + +def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: + # equivalent to ggml_quantize_q4_1 in ggml.c + assert tensor.shape[1] % GGML_QK4_1 == 0 + tensor = tensor.view(-1, GGML_QK4_1) + abs_max_indices = tensor.max(dim=-1, keepdim=True).indices + max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1) + abs_min_indices = tensor.min(dim=-1, keepdim=True).indices + min_values = torch.take_along_dim(tensor, abs_min_indices, dim=-1) + scale = (max_values - min_values) / 15 + tensor = ((tensor - min_values) / scale).round().clamp(min=0, max=15).char() + # compress two int4 weights into a int8 + tensor = tensor[:, :16] | (tensor[:, 16:] << 4) + # add scale into each block + tensor = torch.cat( + (scale.half().view(torch.int8), min_values.half().view(torch.int8), tensor), dim=-1 + ) + return tensor + + +def maybe_quantize_tensor(tensor, ggml_type): + assert tensor.dtype == torch.float32 + + if ggml_type == gguf.GGMLQuantizationType.F32: + return tensor.float() + elif ggml_type == gguf.GGMLQuantizationType.F16: + return tensor.half() + elif ggml_type == gguf.GGMLQuantizationType.Q8_0: + return quantize_q8_0(tensor) + elif ggml_type == gguf.GGMLQuantizationType.Q4_0: + return quantize_q4_0(tensor) + elif ggml_type == gguf.GGMLQuantizationType.Q4_1: + return quantize_q4_1(tensor) + else: + raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})") + + +def get_dtype_and_ggml_type(tensor, ggml_type): + if tensor.ndim == 2: + if tensor.shape[1] % GGML_QK8_0 == 0: + return np.int8, ggml_type + else: + return np.float16, gguf.GGMLQuantizationType.F16 + else: + # 1d weight: convert it to float32 + assert tensor.ndim == 1 + return np.float32, gguf.GGMLQuantizationType.F32 + + +def dump_state_dict(f, weight_names, model_files, ggml_type, config): + keys2names = {} + meta_tensors = {} + weight_scales = {} + + # First operate on meta tensors to find shapes and dtypes for GGUF header. 
+ for name, fn in model_files: + weight, scales = get_weights(fn) + meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta") + weight_scales[name] = (weight, scales) + for key in meta_state_dict.keys(): + keys2names[key] = name + + meta_tensors.update(meta_state_dict) + + for key in weight_names: + meta_tensor = meta_tensors[key] + dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) + quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) + f.add_tensor_info( + key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type + ) + print("Loaded", len(meta_tensors), "files") + + f.write_header_to_file() + f.write_kv_data_to_file() + f.write_ti_data_to_file() + + cache = {} + tensor_info = [] + + for key in weight_names: + if key not in cache: + name = keys2names[key] + weight, scales = weight_scales.pop(name) + state_dict = convert_weight(name, weight, scales, config.experts) + permute_tensors(state_dict, config) + cache.update(state_dict) + tensor = cache.pop(key) + _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) + tensor = maybe_quantize_tensor(tensor, tensor_ggml_type) + + array = tensor.numpy() + print( + f"dumping {key}: {tensor_ggml_type.name}/{array.dtype}, {array.shape}, {array.nbytes} bytes" + ) + f.write_tensor_data(array) + + tensor_info.append((key, tensor.shape, tensor_ggml_type.name)) + + try: + print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) + except NameError: + pass + + if len(tensor_info) != len(weight_names): + print("Warning: not all tensors are converted") + + +def from_numpy(array): + """Like torch.from_numpy, but handle ml_dtypes.bfloat16 too.""" + + if array.dtype == ml_dtypes.bfloat16: + return torch.from_numpy(array.view(np.uint8)).view(torch.bfloat16) + return torch.from_numpy(array) + + +def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None): + # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 + result = {} + + weight = from_numpy(weight).to(device=device, dtype=dtype) + if scales is not None: + scale = from_numpy(scales).to(device=device, dtype=dtype) + # row parallel layers have sharded scale + if len(scale.shape) >= 2 and scale.shape[-2] != 1: + scale = scale[..., None, :] + weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1]) + weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1]) + else: + weight = weight * scale + + # Transpose linear matrix + if len(weight.shape) >= 2 and "token_embd" not in tensor_name: + weight = weight.transpose(-1, -2) + + if tensor_name.endswith("ffn_gate_inp.weight"): + result[tensor_name] = weight[experts] # gather. 
+ elif "experts" not in tensor_name: + result[tensor_name] = weight + else: + # split moe + for i, expert in enumerate(experts): + key = tensor_name.replace("experts", str(i)) + result[key] = weight[expert] + + return result + + +def permute_tensors(state_dict, config): + def permute(weights, n_head): + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) + + for name, tensor in state_dict.items(): + if name.endswith("attn_k.weight"): + state_dict[name] = permute(tensor, config.num_key_value_heads) + elif name.endswith("attn_q.weight"): + state_dict[name] = permute(tensor, config.num_attention_heads) + + +def extract_vocabulary_from_model(vocab): + tokens = [] + scores = [] + toktypes = [] + + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + return tokens, scores, toktypes + + +def get_weight_names(config): + weight_names = ["token_embd.weight"] + for i in range(config.num_hidden_layers): + for j in range(config.num_experts): + weight_names += [ + f"blk.{i}.ffn_gate.{j}.weight", + f"blk.{i}.ffn_down.{j}.weight", + f"blk.{i}.ffn_up.{j}.weight", + ] + + weight_names += [ + f"blk.{i}.attn_k.weight", + f"blk.{i}.attn_output.weight", + f"blk.{i}.attn_q.weight", + f"blk.{i}.attn_v.weight", + f"blk.{i}.attn_norm.weight", + f"blk.{i}.attn_output_norm.weight", + f"blk.{i}.ffn_norm.weight", + f"blk.{i}.layer_output_norm.weight", + f"blk.{i}.ffn_gate_inp.weight", + ] + + weight_names += ["output_norm.weight"] + + return weight_names + + +def convert_grok(args, vocab, ggml_type): + start = time.time() + + def ffn_size(emb_size, widening_factor): + _ffn_size = int(widening_factor * emb_size) * 2 // 3 + _ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8 + return _ffn_size + + config = { + "vocab_size": 128 * 1024, + "hidden_act": "gelu", + "pad_token_id": 0, + "eos_token_id": 2, + "max_position_embeddings": 8192, + "output_multiplier_scale": 0.5773502691896257, + "embedding_multiplier_scale": 78.38367176906169, + "hidden_size": 48 * 128, + "intermediate_size": -1, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "num_hidden_layers": 64, # Change to 1 for quicker debugging. + "num_selected_experts": 2, + "rope_theta": 10000, + "attn_output_multiplier": 0.08838834764831845, + "rms_norm_eps": 1e-5, + } + + config = AttributeDict(config) + + config.intermediate_size = ffn_size(config.hidden_size, 8) + + config.experts = list(range(8)) + if args.experts != "": + config.experts = [int(x, 0) for x in args.experts.split(",")] + + config.num_experts = len(config.experts) + + assert config.num_experts >= 2, "need at least 2 experts" + print("experts to export:", config.experts) + + # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later. 
+ tensor_names = [ + "token_embd.weight", + "output_norm.weight", + ] + for i in range(config.num_hidden_layers): + tensor_names += [ + f"blk.{i}.ffn_gate.experts.weight", + f"blk.{i}.ffn_down.experts.weight", + f"blk.{i}.ffn_up.experts.weight", + f"blk.{i}.attn_k.weight", + f"blk.{i}.attn_output.weight", + f"blk.{i}.attn_q.weight", + f"blk.{i}.attn_v.weight", + f"blk.{i}.attn_norm.weight", + f"blk.{i}.attn_output_norm.weight", + f"blk.{i}.ffn_norm.weight", + f"blk.{i}.layer_output_norm.weight", + f"blk.{i}.ffn_gate_inp.weight", + ] + + tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)] + f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) + + f.add_name("grok") + f.add_vocab_size(config.vocab_size) + f.add_context_length(config.max_position_embeddings) + f.add_embedding_length(config.hidden_size) + f.add_block_count(config.num_hidden_layers) + f.add_feed_forward_length(config.intermediate_size) + f.add_rope_dimension_count(config.hidden_size // config.num_attention_heads) + f.add_head_count(config.num_attention_heads) + f.add_head_count_kv(config.num_key_value_heads) + + f.add_expert_count(config.num_experts) + f.add_expert_used_count(config.num_selected_experts) + f.add_layer_norm_rms_eps(config.rms_norm_eps) + + f.add_rope_freq_base(config.rope_theta) + + f.add_tokenizer_model("llama") + # Extract model vocabulary for model conversion + tokens, scores, toktypes = extract_vocabulary_from_model(vocab) + f.add_token_list(tokens) + f.add_token_scores(scores) + f.add_token_types(toktypes) + + weight_names = get_weight_names(config) + dump_state_dict(f, weight_names, tensor_map, ggml_type, config) + f.close() + + delta = time.time() - start + + print(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec") + + +def load_vocab(path): + def load_spm(p): + print(f"Loading vocab file {p}") + return SentencePieceVocab(p) + + # Be extra-friendly and accept either a file or a directory. Also, if it's + # a directory, it might be the model directory, and tokenizer.model might + # be in the parent of that. + if path.is_dir(): + path2 = path / "tokenizer.model" + # Use `.parent` instead of /.. to handle the symlink case better. + path3 = path.parent / "tokenizer.model" + + if path2.exists(): + return load_spm(path2) + elif path3.exists(): + return load_spm(path3) + + raise FileNotFoundError( + f"Could not find tokenizer.model in {path} or its parent; " + "if it's in another directory, pass the directory as --vocab-dir" + ) + + +def main(): + parser = argparse.ArgumentParser("convert_grok") + parser.add_argument("-i", "--input", type=str) + parser.add_argument("-o", "--save_path", type=pathlib.Path) + parser.add_argument( + "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"] + ) + parser.add_argument("--vocab_dir", type=str, default="") + parser.add_argument("--experts", type=str, default="") + args = parser.parse_args() + + vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input)) + ggml_type = gguf.GGMLQuantizationType[args.type.upper()] + convert_grok(args, vocab, ggml_type) + + +if __name__ == "__main__": + main() From 3c577438743cdd8a94efa0684ca9a288b38a1d95 Mon Sep 17 00:00:00 2001 From: Heiner Date: Fri, 3 May 2024 22:38:05 +0200 Subject: [PATCH 02/13] Don't split MoE weights. As per https://github.com/ggerganov/llama.cpp/pull/7058#issuecomment-2092967508. This helps avoid a memcopy when running. 
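For reference, a minimal sketch of the layout change (shapes here are scaled-down stand-ins for the real (8, 32768, 6144) expert stacks, and the expert list mirrors the --experts 1,2 example from the docstring): instead of splitting each MoE weight into per-expert 2-D tensors, the stacked 3-D tensor is kept and the requested experts are selected with a single gather, which is what avoids the extra copy at load time.

    import torch

    # Scaled-down stand-in for a real (8, 32768, 6144) expert stack.
    stacked = torch.randn(8, 32, 16)
    experts = [1, 2]  # e.g. --experts 1,2

    # Old behaviour: one 2-D tensor per selected expert,
    # e.g. blk.0.ffn_up.0.weight, blk.0.ffn_up.1.weight, ...
    per_expert = {f"blk.0.ffn_up.{i}.weight": stacked[e] for i, e in enumerate(experts)}

    # New behaviour: a single 3-D tensor blk.0.ffn_up_exps.weight,
    # selected with one gather and written out as-is.
    gathered = stacked[experts]
    assert gathered.shape == (len(experts), 32, 16)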
--- convert_grok.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 8096931f5..d196f7a7b 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -185,14 +185,14 @@ def maybe_quantize_tensor(tensor, ggml_type): def get_dtype_and_ggml_type(tensor, ggml_type): - if tensor.ndim == 2: + if tensor.ndim in (2, 3): if tensor.shape[1] % GGML_QK8_0 == 0: return np.int8, ggml_type else: return np.float16, gguf.GGMLQuantizationType.F16 else: # 1d weight: convert it to float32 - assert tensor.ndim == 1 + assert tensor.ndim == 1, tensor return np.float32, gguf.GGMLQuantizationType.F32 @@ -236,15 +236,15 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config): cache.update(state_dict) tensor = cache.pop(key) _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) - tensor = maybe_quantize_tensor(tensor, tensor_ggml_type) + array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() - array = tensor.numpy() print( - f"dumping {key}: {tensor_ggml_type.name}/{array.dtype}, {array.shape}, {array.nbytes} bytes" + f"dumping {key}:", + f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes", ) f.write_tensor_data(array) - tensor_info.append((key, tensor.shape, tensor_ggml_type.name)) + tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name)) try: print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) @@ -282,15 +282,10 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de if len(weight.shape) >= 2 and "token_embd" not in tensor_name: weight = weight.transpose(-1, -2) - if tensor_name.endswith("ffn_gate_inp.weight"): + if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"): result[tensor_name] = weight[experts] # gather. elif "experts" not in tensor_name: result[tensor_name] = weight - else: - # split moe - for i, expert in enumerate(experts): - key = tensor_name.replace("experts", str(i)) - result[key] = weight[expert] return result @@ -328,14 +323,10 @@ def extract_vocabulary_from_model(vocab): def get_weight_names(config): weight_names = ["token_embd.weight"] for i in range(config.num_hidden_layers): - for j in range(config.num_experts): - weight_names += [ - f"blk.{i}.ffn_gate.{j}.weight", - f"blk.{i}.ffn_down.{j}.weight", - f"blk.{i}.ffn_up.{j}.weight", - ] - weight_names += [ + f"blk.{i}.ffn_gate_exps.weight", + f"blk.{i}.ffn_down_exps.weight", + f"blk.{i}.ffn_up_exps.weight", f"blk.{i}.attn_k.weight", f"blk.{i}.attn_output.weight", f"blk.{i}.attn_q.weight", @@ -399,9 +390,9 @@ def convert_grok(args, vocab, ggml_type): ] for i in range(config.num_hidden_layers): tensor_names += [ - f"blk.{i}.ffn_gate.experts.weight", - f"blk.{i}.ffn_down.experts.weight", - f"blk.{i}.ffn_up.experts.weight", + f"blk.{i}.ffn_gate_exps.weight", + f"blk.{i}.ffn_down_exps.weight", + f"blk.{i}.ffn_up_exps.weight", f"blk.{i}.attn_k.weight", f"blk.{i}.attn_output.weight", f"blk.{i}.attn_q.weight", From 08427630c30bb8c4c3fb46decab7a6c481267a74 Mon Sep 17 00:00:00 2001 From: Heiner Date: Fri, 3 May 2024 23:44:47 +0200 Subject: [PATCH 03/13] Use only one list of weight names, with values from the gguf module. This saves weights in the order in which they are in the Grok-1 files. Since we operate weight-by-weight now, we no longer need caches and name2key translations. Per reviewer request, I also moved to using keys in gguf.TENSOR_NAMES. 
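For reference, a small sketch of how the gguf constants expand into the flat names used here (assuming the bundled gguf-py package is importable, as the converter itself assumes). Per-block entries carry a {bid} placeholder that is filled with the layer index; the converter appends the ".weight" suffix when writing tensor info.

    import gguf

    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD])
    # expected: token_embd
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_K].format(bid=3))
    # expected: blk.3.attn_k
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.FFN_UP_EXP].format(bid=3))
    # expected: blk.3.ffn_up_exps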
--- convert_grok.py | 149 ++++++++++++++++++++---------------------------- 1 file changed, 62 insertions(+), 87 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index d196f7a7b..f3aa6d369 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type): return np.float32, gguf.GGMLQuantizationType.F32 -def dump_state_dict(f, weight_names, model_files, ggml_type, config): - keys2names = {} - meta_tensors = {} - weight_scales = {} +def dump_state_dict(f, ggml_type, input_dir, config): + weight_names = get_weight_names(config.num_hidden_layers) + weights = {} # First operate on meta tensors to find shapes and dtypes for GGUF header. - for name, fn in model_files: - weight, scales = get_weights(fn) - meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta") - weight_scales[name] = (weight, scales) - for key in meta_state_dict.keys(): - keys2names[key] = name - - meta_tensors.update(meta_state_dict) - - for key in weight_names: - meta_tensor = meta_tensors[key] + for idx, name in enumerate(weight_names): + weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") + meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta") dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) f.add_tensor_info( - key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type + f"{name}.weight", + list(meta_tensor.shape), + dtype, + quantized_meta_tensor.nbytes, + tensor_ggml_type, ) - print("Loaded", len(meta_tensors), "files") + weights[name] = weight, scales + print("Loaded", len(weight_names), "files") f.write_header_to_file() f.write_kv_data_to_file() f.write_ti_data_to_file() - cache = {} + # Now write actual tensor data. 
tensor_info = [] - for key in weight_names: - if key not in cache: - name = keys2names[key] - weight, scales = weight_scales.pop(name) - state_dict = convert_weight(name, weight, scales, config.experts) - permute_tensors(state_dict, config) - cache.update(state_dict) - tensor = cache.pop(key) + for name in weight_names: + weight, scales = weights.pop(name) + tensor = convert_weight(name, weight, scales, config.experts) + tensor = maybe_permute_tensor(name, tensor, config) _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() print( - f"dumping {key}:", + f"dumping {name}:", f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes", ) f.write_tensor_data(array) - tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name)) + tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name)) try: print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) @@ -263,10 +255,8 @@ def from_numpy(array): return torch.from_numpy(array) -def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None): +def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None): # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 - result = {} - weight = from_numpy(weight).to(device=device, dtype=dtype) if scales is not None: scale = from_numpy(scales).to(device=device, dtype=dtype) @@ -279,18 +269,16 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de weight = weight * scale # Transpose linear matrix - if len(weight.shape) >= 2 and "token_embd" not in tensor_name: + if len(weight.shape) >= 2 and "token_embd" not in name: weight = weight.transpose(-1, -2) - if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"): - result[tensor_name] = weight[experts] # gather. - elif "experts" not in tensor_name: - result[tensor_name] = weight + if name.endswith("ffn_gate_inp") or name.endswith("_exps"): + weight = weight[experts] # gather. 
- return result + return weight -def permute_tensors(state_dict, config): +def maybe_permute_tensor(name, tensor, config): def permute(weights, n_head): return ( weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) @@ -298,11 +286,12 @@ def permute_tensors(state_dict, config): .reshape(weights.shape) ) - for name, tensor in state_dict.items(): - if name.endswith("attn_k.weight"): - state_dict[name] = permute(tensor, config.num_key_value_heads) - elif name.endswith("attn_q.weight"): - state_dict[name] = permute(tensor, config.num_attention_heads) + if name.endswith("attn_k"): + return permute(tensor, config.num_key_value_heads) + elif name.endswith("attn_q"): + return permute(tensor, config.num_attention_heads) + + return tensor def extract_vocabulary_from_model(vocab): @@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab): return tokens, scores, toktypes -def get_weight_names(config): - weight_names = ["token_embd.weight"] - for i in range(config.num_hidden_layers): - weight_names += [ - f"blk.{i}.ffn_gate_exps.weight", - f"blk.{i}.ffn_down_exps.weight", - f"blk.{i}.ffn_up_exps.weight", - f"blk.{i}.attn_k.weight", - f"blk.{i}.attn_output.weight", - f"blk.{i}.attn_q.weight", - f"blk.{i}.attn_v.weight", - f"blk.{i}.attn_norm.weight", - f"blk.{i}.attn_output_norm.weight", - f"blk.{i}.ffn_norm.weight", - f"blk.{i}.layer_output_norm.weight", - f"blk.{i}.ffn_gate_inp.weight", - ] +def get_weight_names(num_hidden_layers=64): + """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files.""" - weight_names += ["output_norm.weight"] + weight_names = [ + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD], + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM], + ] + + layer = ( + gguf.MODEL_TENSOR.FFN_GATE_EXP, + gguf.MODEL_TENSOR.FFN_DOWN_EXP, + gguf.MODEL_TENSOR.FFN_UP_EXP, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_NORM, + gguf.MODEL_TENSOR.ATTN_OUT_NORM, + gguf.MODEL_TENSOR.FFN_NORM, + gguf.MODEL_TENSOR.LAYER_OUT_NORM, + gguf.MODEL_TENSOR.FFN_GATE_INP, + ) + + for bid in range(num_hidden_layers): + for key in layer: + weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid)) return weight_names @@ -383,28 +379,6 @@ def convert_grok(args, vocab, ggml_type): assert config.num_experts >= 2, "need at least 2 experts" print("experts to export:", config.experts) - # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later. 
- tensor_names = [ - "token_embd.weight", - "output_norm.weight", - ] - for i in range(config.num_hidden_layers): - tensor_names += [ - f"blk.{i}.ffn_gate_exps.weight", - f"blk.{i}.ffn_down_exps.weight", - f"blk.{i}.ffn_up_exps.weight", - f"blk.{i}.attn_k.weight", - f"blk.{i}.attn_output.weight", - f"blk.{i}.attn_q.weight", - f"blk.{i}.attn_v.weight", - f"blk.{i}.attn_norm.weight", - f"blk.{i}.attn_output_norm.weight", - f"blk.{i}.ffn_norm.weight", - f"blk.{i}.layer_output_norm.weight", - f"blk.{i}.ffn_gate_inp.weight", - ] - - tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)] f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) f.add_name("grok") @@ -430,8 +404,7 @@ def convert_grok(args, vocab, ggml_type): f.add_token_scores(scores) f.add_token_types(toktypes) - weight_names = get_weight_names(config) - dump_state_dict(f, weight_names, tensor_map, ggml_type, config) + dump_state_dict(f, ggml_type, args.input_dir, config) f.close() delta = time.time() - start @@ -465,7 +438,7 @@ def load_vocab(path): def main(): parser = argparse.ArgumentParser("convert_grok") - parser.add_argument("-i", "--input", type=str) + parser.add_argument("-i", "--input_dir", type=str) parser.add_argument("-o", "--save_path", type=pathlib.Path) parser.add_argument( "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"] @@ -474,7 +447,9 @@ def main(): parser.add_argument("--experts", type=str, default="") args = parser.parse_args() - vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input)) + vocab = load_vocab( + pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir) + ) ggml_type = gguf.GGMLQuantizationType[args.type.upper()] convert_grok(args, vocab, ggml_type) From 5bc4f10ee97abf38730d33fa88ca9d25066670d3 Mon Sep 17 00:00:00 2001 From: Brian Date: Fri, 10 May 2024 01:17:56 +1000 Subject: [PATCH 04/13] Update convert_grok.py to use logging module --- convert_grok.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index f3aa6d369..aba0793b2 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -34,6 +34,8 @@ if "NO_LOCAL_GGUF" not in os.environ: import gguf +logger = logging.getLogger("convert_grok") + GGML_QK8_0 = 32 GGML_QK4_0 = 32 GGML_QK4_1 = 32 @@ -214,7 +216,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): tensor_ggml_type, ) weights[name] = weight, scales - print("Loaded", len(weight_names), "files") + logger.info("Loaded", len(weight_names), "files") f.write_header_to_file() f.write_kv_data_to_file() @@ -230,7 +232,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() - print( + logger.debug( f"dumping {name}:", f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes", ) @@ -239,12 +241,12 @@ def dump_state_dict(f, ggml_type, input_dir, config): tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name)) try: - print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) + print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) # noqa: NP100 except NameError: pass if len(tensor_info) != len(weight_names): - print("Warning: not all tensors are converted") + logger.warning("Not all tensors are converted") def from_numpy(array): @@ -377,7 +379,7 
@@ def convert_grok(args, vocab, ggml_type): config.num_experts = len(config.experts) assert config.num_experts >= 2, "need at least 2 experts" - print("experts to export:", config.experts) + logger.info("experts to export:", config.experts) f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) @@ -409,12 +411,12 @@ def convert_grok(args, vocab, ggml_type): delta = time.time() - start - print(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec") + logger.info(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec") def load_vocab(path): def load_spm(p): - print(f"Loading vocab file {p}") + logger.info(f"Loading vocab file {p}") return SentencePieceVocab(p) # Be extra-friendly and accept either a file or a directory. Also, if it's @@ -445,8 +447,12 @@ def main(): ) parser.add_argument("--vocab_dir", type=str, default="") parser.add_argument("--experts", type=str, default="") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + args = parser.parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + vocab = load_vocab( pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir) ) From d894497a9622fd1b5cac8dcb3ce8dba051641273 Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 9 May 2024 17:12:07 +0200 Subject: [PATCH 05/13] Move print to logging: Fixes. --- convert_grok.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index aba0793b2..d71d647dd 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -11,6 +11,7 @@ To run: """ import argparse +import logging import mmap import os import pathlib @@ -34,8 +35,6 @@ if "NO_LOCAL_GGUF" not in os.environ: import gguf -logger = logging.getLogger("convert_grok") - GGML_QK8_0 = 32 GGML_QK4_0 = 32 GGML_QK4_1 = 32 @@ -216,7 +215,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): tensor_ggml_type, ) weights[name] = weight, scales - logger.info("Loaded", len(weight_names), "files") + logging.debug("Loaded %i files", len(weight_names)) f.write_header_to_file() f.write_kv_data_to_file() @@ -232,21 +231,23 @@ def dump_state_dict(f, ggml_type, input_dir, config): _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() - logger.debug( - f"dumping {name}:", - f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes", + logging.info( + f"dumping {name}:" + f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes" ) f.write_tensor_data(array) tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name)) try: - print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql")) # noqa: NP100 + print( + tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql") + ) # noqa: NP100 except NameError: pass if len(tensor_info) != len(weight_names): - logger.warning("Not all tensors are converted") + logging.warning("Not all tensors are converted") def from_numpy(array): @@ -379,7 +380,7 @@ def convert_grok(args, vocab, ggml_type): config.num_experts = len(config.experts) assert config.num_experts >= 2, "need at least 2 experts" - logger.info("experts to export:", config.experts) + logging.info("experts to export: %s", config.experts) f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) @@ -411,12 +412,12 @@ def convert_grok(args, vocab, ggml_type): 
delta = time.time() - start - logger.info(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec") + logging.info(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec") def load_vocab(path): def load_spm(p): - logger.info(f"Loading vocab file {p}") + logging.info(f"Loading vocab file {p}") return SentencePieceVocab(p) # Be extra-friendly and accept either a file or a directory. Also, if it's @@ -452,7 +453,7 @@ def main(): args = parser.parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - + vocab = load_vocab( pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir) ) From ef671c693d99dc7e460f54aa69503f01cd45f84d Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 9 May 2024 22:43:38 +0200 Subject: [PATCH 06/13] Address review comments by foldl. --- convert_grok.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index d71d647dd..7d0dbfc77 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): # First operate on meta tensors to find shapes and dtypes for GGUF header. for idx, name in enumerate(weight_names): weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") - meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta") + meta_tensor = convert_weight(name, weight, scales, config, device="meta") dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) f.add_tensor_info( @@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): for name in weight_names: weight, scales = weights.pop(name) - tensor = convert_weight(name, weight, scales, config.experts) - tensor = maybe_permute_tensor(name, tensor, config) + tensor = convert_weight(name, weight, scales, config) _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() @@ -258,7 +257,7 @@ def from_numpy(array): return torch.from_numpy(array) -def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None): +def convert_weight(name, weight, scales, config, dtype=torch.float32, device=None): # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 weight = from_numpy(weight).to(device=device, dtype=dtype) if scales is not None: @@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No else: weight = weight * scale - # Transpose linear matrix - if len(weight.shape) >= 2 and "token_embd" not in name: + if name == "token_embd": + weight *= config.embedding_multiplier_scale + elif len(weight.shape) >= 2: + # Transpose linear matrix weight = weight.transpose(-1, -2) + if name.endswith("ffn_gate_inp") or name.endswith("_exps"): - weight = weight[experts] # gather. + weight = weight[config.experts] # gather. 
return weight -def maybe_permute_tensor(name, tensor, config): - def permute(weights, n_head): - return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape) - ) - - if name.endswith("attn_k"): - return permute(tensor, config.num_key_value_heads) - elif name.endswith("attn_q"): - return permute(tensor, config.num_attention_heads) - - return tensor - - def extract_vocabulary_from_model(vocab): tokens = [] scores = [] From 9a0629d5456b8b50f36cfabc6c5c2b4f5d17401c Mon Sep 17 00:00:00 2001 From: Heiner Date: Fri, 10 May 2024 12:40:05 +0200 Subject: [PATCH 07/13] Don't multiply embeddings with embedding_multiplier_scale as it happens in llama.cpp. --- convert_grok.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 7d0dbfc77..4475da3c1 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -270,13 +270,9 @@ def convert_weight(name, weight, scales, config, dtype=torch.float32, device=Non else: weight = weight * scale - if name == "token_embd": - weight *= config.embedding_multiplier_scale - elif len(weight.shape) >= 2: + if name != "token_embd" and len(weight.shape) >= 2: # Transpose linear matrix weight = weight.transpose(-1, -2) - - if name.endswith("ffn_gate_inp") or name.endswith("_exps"): weight = weight[config.experts] # gather. From f177b6596cdd8602715ba948393c2e53a7961c64 Mon Sep 17 00:00:00 2001 From: Heiner Date: Tue, 21 May 2024 23:17:14 +0200 Subject: [PATCH 08/13] Fix layer order. --- convert_grok.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 4475da3c1..514692718 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -126,7 +126,7 @@ def get_weights(fn): def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q8_0 in ggml.c assert tensor.shape[1] % GGML_QK8_0 == 0 - tensor = tensor.view(-1, GGML_QK8_0) + tensor = tensor.reshape(-1, GGML_QK8_0) scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) tensor = (tensor / scale).round().clamp(min=-128, max=127).char() # add scale into each block @@ -152,7 +152,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q4_1 in ggml.c assert tensor.shape[1] % GGML_QK4_1 == 0 - tensor = tensor.view(-1, GGML_QK4_1) + tensor = tensor.reshape(-1, GGML_QK4_1) abs_max_indices = tensor.max(dim=-1, keepdim=True).indices max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1) abs_min_indices = tensor.min(dim=-1, keepdim=True).indices @@ -185,15 +185,13 @@ def maybe_quantize_tensor(tensor, ggml_type): raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})") -def get_dtype_and_ggml_type(tensor, ggml_type): - if tensor.ndim in (2, 3): +def get_dtype_and_ggml_type(name, tensor, ggml_type): + if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name: if tensor.shape[1] % GGML_QK8_0 == 0: return np.int8, ggml_type else: return np.float16, gguf.GGMLQuantizationType.F16 else: - # 1d weight: convert it to float32 - assert tensor.ndim == 1, tensor return np.float32, gguf.GGMLQuantizationType.F32 @@ -205,7 +203,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): for idx, name in enumerate(weight_names): weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") meta_tensor = convert_weight(name, weight, scales, config, device="meta") - 
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type) + dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) f.add_tensor_info( f"{name}.weight", @@ -227,7 +225,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): for name in weight_names: weight, scales = weights.pop(name) tensor = convert_weight(name, weight, scales, config) - _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type) + _, tensor_ggml_type = get_dtype_and_ggml_type(name, tensor, ggml_type) array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy() logging.info( @@ -317,7 +315,10 @@ def get_weight_names(num_hidden_layers=64): gguf.MODEL_TENSOR.FFN_GATE_INP, ) - for bid in range(num_hidden_layers): + layers = [str(bid) for bid in range(64)] + layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ... + + for bid in layers[:num_hidden_layers]: for key in layer: weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid)) @@ -333,7 +334,6 @@ def convert_grok(args, vocab, ggml_type): return _ffn_size config = { - "vocab_size": 128 * 1024, "hidden_act": "gelu", "pad_token_id": 0, "eos_token_id": 2, @@ -366,8 +366,7 @@ def convert_grok(args, vocab, ggml_type): f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE) - f.add_name("grok") - f.add_vocab_size(config.vocab_size) + f.add_name("grok-1") f.add_context_length(config.max_position_embeddings) f.add_embedding_length(config.hidden_size) f.add_block_count(config.num_hidden_layers) @@ -389,6 +388,8 @@ def convert_grok(args, vocab, ggml_type): f.add_token_scores(scores) f.add_token_types(toktypes) + f.add_quantization_version(ggml_type) + dump_state_dict(f, ggml_type, args.input_dir, config) f.close() From e2f13a3346a08cd27bc2c1731d2e404618120ad4 Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 11:10:59 +0200 Subject: [PATCH 09/13] Use Q8_0 quantization from gguf module. 
This makes tensors exactly as in https://huggingface.co/Arki05/Grok-1-GGUF/tree/main/Q8_0 --- convert_grok.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 514692718..35ce46cd1 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -124,7 +124,7 @@ def get_weights(fn): def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: - # equivalent to ggml_quantize_q8_0 in ggml.c + # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero) assert tensor.shape[1] % GGML_QK8_0 == 0 tensor = tensor.reshape(-1, GGML_QK8_0) scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) @@ -135,7 +135,7 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: - # equivalent to ggml_quantize_q4_0 in ggml.c + # equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero) assert tensor.shape[1] % GGML_QK4_0 == 0 tensor = tensor.reshape(-1, GGML_QK4_0) abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices @@ -150,7 +150,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: - # equivalent to ggml_quantize_q4_1 in ggml.c + # equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero) assert tensor.shape[1] % GGML_QK4_1 == 0 tensor = tensor.reshape(-1, GGML_QK4_1) abs_max_indices = tensor.max(dim=-1, keepdim=True).indices @@ -170,13 +170,14 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: def maybe_quantize_tensor(tensor, ggml_type): assert tensor.dtype == torch.float32 - if ggml_type == gguf.GGMLQuantizationType.F32: return tensor.float() elif ggml_type == gguf.GGMLQuantizationType.F16: return tensor.half() elif ggml_type == gguf.GGMLQuantizationType.Q8_0: - return quantize_q8_0(tensor) + if tensor.device.type == "meta": + return quantize_q8_0(tensor) # Cannot convert into numpy array. + return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy())) elif ggml_type == gguf.GGMLQuantizationType.Q4_0: return quantize_q4_0(tensor) elif ggml_type == gguf.GGMLQuantizationType.Q4_1: From 60b29ea6e43c4be3de1b950e1588e644040554bf Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 11:26:35 +0200 Subject: [PATCH 10/13] More constants from gguf. --- convert_grok.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 35ce46cd1..714a216fc 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -35,9 +35,9 @@ if "NO_LOCAL_GGUF" not in os.environ: import gguf -GGML_QK8_0 = 32 -GGML_QK4_0 = 32 -GGML_QK4_1 = 32 +QK8_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0][0] +QK4_0 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0][0] +QK4_1 = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_1][0] # Heuristic to avoid having to fully parse pickle files. 
@@ -125,8 +125,8 @@ def get_weights(fn): def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero) - assert tensor.shape[1] % GGML_QK8_0 == 0 - tensor = tensor.reshape(-1, GGML_QK8_0) + assert tensor.shape[1] % QK8_0 == 0 + tensor = tensor.reshape(-1, QK8_0) scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) tensor = (tensor / scale).round().clamp(min=-128, max=127).char() # add scale into each block @@ -136,8 +136,8 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero) - assert tensor.shape[1] % GGML_QK4_0 == 0 - tensor = tensor.reshape(-1, GGML_QK4_0) + assert tensor.shape[1] % QK4_0 == 0 + tensor = tensor.reshape(-1, QK4_0) abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1) scale = max_values / -8 @@ -151,8 +151,8 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor: def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor: # equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero) - assert tensor.shape[1] % GGML_QK4_1 == 0 - tensor = tensor.reshape(-1, GGML_QK4_1) + assert tensor.shape[1] % QK4_1 == 0 + tensor = tensor.reshape(-1, QK4_1) abs_max_indices = tensor.max(dim=-1, keepdim=True).indices max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1) abs_min_indices = tensor.min(dim=-1, keepdim=True).indices @@ -188,7 +188,7 @@ def maybe_quantize_tensor(tensor, ggml_type): def get_dtype_and_ggml_type(name, tensor, ggml_type): if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name: - if tensor.shape[1] % GGML_QK8_0 == 0: + if tensor.shape[1] % QK8_0 == 0: return np.int8, ggml_type else: return np.float16, gguf.GGMLQuantizationType.F16 From 0a1ef1127fc45649694e8d837759e00ae21c237e Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 15:07:27 +0200 Subject: [PATCH 11/13] Write tensors in layer order. --- convert_grok.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 714a216fc..05a304407 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -197,12 +197,20 @@ def get_dtype_and_ggml_type(name, tensor, ggml_type): def dump_state_dict(f, ggml_type, input_dir, config): - weight_names = get_weight_names(config.num_hidden_layers) weights = {} - # First operate on meta tensors to find shapes and dtypes for GGUF header. - for idx, name in enumerate(weight_names): - weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000") + # Load weights in file order (mmap'ed). + for idx, name in enumerate(get_weight_names(config.num_hidden_layers)): + weights[name] = get_weights(f"{input_dir}/tensor{idx:05}_000") + + logging.debug("Loaded %i files", len(weights)) + + # But write in layer order. + weight_names = get_weight_names(config.num_hidden_layers, lexicographic=False) + + # Operate on meta tensors to find shapes and dtypes for GGUF header. 
+ for name in weight_names: + weight, scales = weights[name] meta_tensor = convert_weight(name, weight, scales, config, device="meta") dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type) quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type) @@ -213,8 +221,6 @@ def dump_state_dict(f, ggml_type, input_dir, config): quantized_meta_tensor.nbytes, tensor_ggml_type, ) - weights[name] = weight, scales - logging.debug("Loaded %i files", len(weight_names)) f.write_header_to_file() f.write_kv_data_to_file() @@ -244,7 +250,7 @@ def dump_state_dict(f, ggml_type, input_dir, config): except NameError: pass - if len(tensor_info) != len(weight_names): + if weights: logging.warning("Not all tensors are converted") @@ -293,8 +299,10 @@ def extract_vocabulary_from_model(vocab): return tokens, scores, toktypes -def get_weight_names(num_hidden_layers=64): - """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files.""" +def get_weight_names(num_hidden_layers=64, lexicographic=True): + """Return Grok-1 weight names. + + If `lexicographic` is set, the order is as in the tensor#####_000 files.""" weight_names = [ gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD], @@ -317,7 +325,10 @@ def get_weight_names(num_hidden_layers=64): ) layers = [str(bid) for bid in range(64)] - layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ... + + if lexicographic: + # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ... + layers.sort() for bid in layers[:num_hidden_layers]: for key in layer: From abc958b07e8f183a8c8c23e04275e4e07ec9a8b6 Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 15:19:49 +0200 Subject: [PATCH 12/13] Move noqa comment to where the lastest flake8 likes it. --- convert_grok.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 05a304407..43615d6ba 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -244,9 +244,9 @@ def dump_state_dict(f, ggml_type, input_dir, config): tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name)) try: - print( + print( # noqa: NP100 tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql") - ) # noqa: NP100 + ) except NameError: pass From 739648f3e68c90fb823e06e6922e73ff026a26f5 Mon Sep 17 00:00:00 2001 From: Heiner Date: Thu, 23 May 2024 20:24:47 +0200 Subject: [PATCH 13/13] Implement Q8_0 quantization fully in PyTorch. This is equivalent to gguf.quantize_q8_0 but doesn't round-trip to Numpy. --- convert_grok.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/convert_grok.py b/convert_grok.py index 43615d6ba..e5d9aa597 100644 --- a/convert_grok.py +++ b/convert_grok.py @@ -123,12 +123,21 @@ def get_weights(fn): assert len(arrays) in (1, 2) +def torch_roundf(t: torch.Tensor) -> torch.Tensor: + """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py.""" + a = abs(t) + floored = torch.floor(a) + b = floored + torch.floor(2 * (a - floored)) + return torch.sign(t) * b + + def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor: - # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero) + # Equivalent to gguf.quantize_q8_0 but PyTorch instead of Numpy. 
assert tensor.shape[1] % QK8_0 == 0 tensor = tensor.reshape(-1, QK8_0) scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1) - tensor = (tensor / scale).round().clamp(min=-128, max=127).char() + iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0) + tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char() # add scale into each block tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1) return tensor @@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type): elif ggml_type == gguf.GGMLQuantizationType.F16: return tensor.half() elif ggml_type == gguf.GGMLQuantizationType.Q8_0: - if tensor.device.type == "meta": - return quantize_q8_0(tensor) # Cannot convert into numpy array. - return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy())) + return quantize_q8_0(tensor) elif ggml_type == gguf.GGMLQuantizationType.Q4_0: return quantize_q4_0(tensor) elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
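The torch-only quantize_q8_0 above is intended to be byte-for-byte equivalent to the Numpy reference used in the previous revision of this series. A quick, illustrative consistency check (assuming gguf.quantize_q8_0 from the bundled gguf-py is available and convert_grok's quantize_q8_0 is in scope) could look like this:

    import gguf
    import torch

    x = torch.randn(16, 64, dtype=torch.float32)  # rows divisible by QK8_0 == 32

    a = quantize_q8_0(x).numpy()        # PyTorch implementation from this patch
    b = gguf.quantize_q8_0(x.numpy())   # Numpy reference in gguf-py

    assert a.tobytes() == b.tobytes()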