Script to convert Grok-1 weights from raw JAX pickle files.

This commit is contained in:
Heiner 2024-04-24 23:39:06 +02:00
parent 3cbd23ed88
commit 6ddf93b286

492
convert_grok.py Normal file
View file

@ -0,0 +1,492 @@
"""
Convert Grok-1 weights to GGUF format.
Example invocation:
python -m convert_grok -i path/to/grok-1/ckpt-0 --vocab_dir path/to/grok -o grok.bin -t q4_0 --experts 1,2
To run:
./build/bin/main -m grok.bin -p "The answer to life the universe and everything is" -s 1 -n 3 -ngl 1
"""
import argparse
import mmap
import os
import pathlib
import pickletools
import sys
import time
import ml_dtypes
import numpy as np
import torch
try:
from tabulate import tabulate
except ModuleNotFoundError:
pass
from convert import SentencePieceVocab
if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(pathlib.Path(__file__).parent / "gguf-py"))
import gguf
GGML_QK8_0 = 32
GGML_QK4_0 = 32
GGML_QK4_1 = 32
# Heuristic to avoid having to fully parse pickle files.
FP32_SHAPES = {805306368: (131072, 6144), 6144: (6144,), 49152: (6144, 8)}
BF16_SHAPES = {
262144: (8, 1, 32768),
393216: (8, 8, 6144),
1024: (1, 1024),
49152: (8, 6144),
6144: (1, 6144),
}
class AttributeDict(dict):
def __getattr__(self, key):
return self.__getitem__(key) if key in self else super().__getattr__(key)
__setattr__ = dict.__setitem__
def _genops(data):
view = memoryview(data)
code2op = {ord(d.code): d for d in pickletools.opcodes}
dataops = {
"BINBYTES": pickletools.read_uint4,
"BINBYTES8": pickletools.read_uint8,
}
while True:
pos = data.tell()
code = data.read_byte()
opcode = code2op[code]
arg = None
if opcode.arg is not None:
if opcode.name not in dataops:
arg = opcode.arg.reader(data)
else:
size = dataops[opcode.name](data)
p = data.tell()
arg = np.frombuffer(view[p : p + size], dtype=np.uint8)
data.seek(size, 1)
yield opcode, arg, pos
if code == ord(b"."):
break
def genops(fn):
"""Yield (opcode, arg, pos) from for a pickle file.
Uses mmap to avoid copies of binary data (e.g., np and JAX arrays)."""
with open(fn, "rb") as f:
yield from _genops(mmap.mmap(f.fileno(), length=0, flags=mmap.MAP_PRIVATE))
def get_weights(fn):
"""Returns tensor/array data in Grok pickle files, zero copy."""
arrays = []
for unused_opcode, arg, unused_pos in genops(fn):
if isinstance(arg, np.ndarray):
arrays.append(arg)
if len(arrays) == 1:
# Plain numpy array.
array = arrays[0].view(np.float32)
array = array.reshape(FP32_SHAPES[array.size])
return array, None
elif len(arrays) == 2:
weight, scales = arrays
scales = scales.view(ml_dtypes.bfloat16)
scales = scales.reshape(BF16_SHAPES[scales.size])
weight = weight.view(np.int8)
shape = list(scales.shape)
shape[-2] = -1
weight = weight.reshape(shape)
return weight, scales
assert len(arrays) in (1, 2)
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q8_0 in ggml.c
assert tensor.shape[1] % GGML_QK8_0 == 0
tensor = tensor.view(-1, GGML_QK8_0)
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
# add scale into each block
tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
return tensor
def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
assert tensor.shape[1] % GGML_QK4_0 == 0
tensor = tensor.reshape(-1, GGML_QK4_0)
abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
scale = max_values / -8
tensor = (tensor / scale + 8).round().clamp(min=0, max=15).char()
# compress two int4 weights into a int8
tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
# add scale into each block
tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
return tensor
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_1 in ggml.c
assert tensor.shape[1] % GGML_QK4_1 == 0
tensor = tensor.view(-1, GGML_QK4_1)
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
abs_min_indices = tensor.min(dim=-1, keepdim=True).indices
min_values = torch.take_along_dim(tensor, abs_min_indices, dim=-1)
scale = (max_values - min_values) / 15
tensor = ((tensor - min_values) / scale).round().clamp(min=0, max=15).char()
# compress two int4 weights into a int8
tensor = tensor[:, :16] | (tensor[:, 16:] << 4)
# add scale into each block
tensor = torch.cat(
(scale.half().view(torch.int8), min_values.half().view(torch.int8), tensor), dim=-1
)
return tensor
def maybe_quantize_tensor(tensor, ggml_type):
assert tensor.dtype == torch.float32
if ggml_type == gguf.GGMLQuantizationType.F32:
return tensor.float()
elif ggml_type == gguf.GGMLQuantizationType.F16:
return tensor.half()
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
return quantize_q8_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
return quantize_q4_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
return quantize_q4_1(tensor)
else:
raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})")
def get_dtype_and_ggml_type(tensor, ggml_type):
if tensor.ndim == 2:
if tensor.shape[1] % GGML_QK8_0 == 0:
return np.int8, ggml_type
else:
return np.float16, gguf.GGMLQuantizationType.F16
else:
# 1d weight: convert it to float32
assert tensor.ndim == 1
return np.float32, gguf.GGMLQuantizationType.F32
def dump_state_dict(f, weight_names, model_files, ggml_type, config):
keys2names = {}
meta_tensors = {}
weight_scales = {}
# First operate on meta tensors to find shapes and dtypes for GGUF header.
for name, fn in model_files:
weight, scales = get_weights(fn)
meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
weight_scales[name] = (weight, scales)
for key in meta_state_dict.keys():
keys2names[key] = name
meta_tensors.update(meta_state_dict)
for key in weight_names:
meta_tensor = meta_tensors[key]
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
f.add_tensor_info(
key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
)
print("Loaded", len(meta_tensors), "files")
f.write_header_to_file()
f.write_kv_data_to_file()
f.write_ti_data_to_file()
cache = {}
tensor_info = []
for key in weight_names:
if key not in cache:
name = keys2names[key]
weight, scales = weight_scales.pop(name)
state_dict = convert_weight(name, weight, scales, config.experts)
permute_tensors(state_dict, config)
cache.update(state_dict)
tensor = cache.pop(key)
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
tensor = maybe_quantize_tensor(tensor, tensor_ggml_type)
array = tensor.numpy()
print(
f"dumping {key}: {tensor_ggml_type.name}/{array.dtype}, {array.shape}, {array.nbytes} bytes"
)
f.write_tensor_data(array)
tensor_info.append((key, tensor.shape, tensor_ggml_type.name))
try:
print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
except NameError:
pass
if len(tensor_info) != len(weight_names):
print("Warning: not all tensors are converted")
def from_numpy(array):
"""Like torch.from_numpy, but handle ml_dtypes.bfloat16 too."""
if array.dtype == ml_dtypes.bfloat16:
return torch.from_numpy(array.view(np.uint8)).view(torch.bfloat16)
return torch.from_numpy(array)
def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
result = {}
weight = from_numpy(weight).to(device=device, dtype=dtype)
if scales is not None:
scale = from_numpy(scales).to(device=device, dtype=dtype)
# row parallel layers have sharded scale
if len(scale.shape) >= 2 and scale.shape[-2] != 1:
scale = scale[..., None, :]
weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1])
weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1])
else:
weight = weight * scale
# Transpose linear matrix
if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
weight = weight.transpose(-1, -2)
if tensor_name.endswith("ffn_gate_inp.weight"):
result[tensor_name] = weight[experts] # gather.
elif "experts" not in tensor_name:
result[tensor_name] = weight
else:
# split moe
for i, expert in enumerate(experts):
key = tensor_name.replace("experts", str(i))
result[key] = weight[expert]
return result
def permute_tensors(state_dict, config):
def permute(weights, n_head):
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
for name, tensor in state_dict.items():
if name.endswith("attn_k.weight"):
state_dict[name] = permute(tensor, config.num_key_value_heads)
elif name.endswith("attn_q.weight"):
state_dict[name] = permute(tensor, config.num_attention_heads)
def extract_vocabulary_from_model(vocab):
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size
return tokens, scores, toktypes
def get_weight_names(config):
weight_names = ["token_embd.weight"]
for i in range(config.num_hidden_layers):
for j in range(config.num_experts):
weight_names += [
f"blk.{i}.ffn_gate.{j}.weight",
f"blk.{i}.ffn_down.{j}.weight",
f"blk.{i}.ffn_up.{j}.weight",
]
weight_names += [
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
weight_names += ["output_norm.weight"]
return weight_names
def convert_grok(args, vocab, ggml_type):
start = time.time()
def ffn_size(emb_size, widening_factor):
_ffn_size = int(widening_factor * emb_size) * 2 // 3
_ffn_size = _ffn_size + (8 - _ffn_size) % 8 # ensure it's a multiple of 8
return _ffn_size
config = {
"vocab_size": 128 * 1024,
"hidden_act": "gelu",
"pad_token_id": 0,
"eos_token_id": 2,
"max_position_embeddings": 8192,
"output_multiplier_scale": 0.5773502691896257,
"embedding_multiplier_scale": 78.38367176906169,
"hidden_size": 48 * 128,
"intermediate_size": -1,
"num_attention_heads": 48,
"num_key_value_heads": 8,
"num_hidden_layers": 64, # Change to 1 for quicker debugging.
"num_selected_experts": 2,
"rope_theta": 10000,
"attn_output_multiplier": 0.08838834764831845,
"rms_norm_eps": 1e-5,
}
config = AttributeDict(config)
config.intermediate_size = ffn_size(config.hidden_size, 8)
config.experts = list(range(8))
if args.experts != "":
config.experts = [int(x, 0) for x in args.experts.split(",")]
config.num_experts = len(config.experts)
assert config.num_experts >= 2, "need at least 2 experts"
print("experts to export:", config.experts)
# Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
tensor_names = [
"token_embd.weight",
"output_norm.weight",
]
for i in range(config.num_hidden_layers):
tensor_names += [
f"blk.{i}.ffn_gate.experts.weight",
f"blk.{i}.ffn_down.experts.weight",
f"blk.{i}.ffn_up.experts.weight",
f"blk.{i}.attn_k.weight",
f"blk.{i}.attn_output.weight",
f"blk.{i}.attn_q.weight",
f"blk.{i}.attn_v.weight",
f"blk.{i}.attn_norm.weight",
f"blk.{i}.attn_output_norm.weight",
f"blk.{i}.ffn_norm.weight",
f"blk.{i}.layer_output_norm.weight",
f"blk.{i}.ffn_gate_inp.weight",
]
tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
f.add_name("grok")
f.add_vocab_size(config.vocab_size)
f.add_context_length(config.max_position_embeddings)
f.add_embedding_length(config.hidden_size)
f.add_block_count(config.num_hidden_layers)
f.add_feed_forward_length(config.intermediate_size)
f.add_rope_dimension_count(config.hidden_size // config.num_attention_heads)
f.add_head_count(config.num_attention_heads)
f.add_head_count_kv(config.num_key_value_heads)
f.add_expert_count(config.num_experts)
f.add_expert_used_count(config.num_selected_experts)
f.add_layer_norm_rms_eps(config.rms_norm_eps)
f.add_rope_freq_base(config.rope_theta)
f.add_tokenizer_model("llama")
# Extract model vocabulary for model conversion
tokens, scores, toktypes = extract_vocabulary_from_model(vocab)
f.add_token_list(tokens)
f.add_token_scores(scores)
f.add_token_types(toktypes)
weight_names = get_weight_names(config)
dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
f.close()
delta = time.time() - start
print(f"grok GGUF model saved to {args.save_path}. Total time {delta:.2f} sec")
def load_vocab(path):
def load_spm(p):
print(f"Loading vocab file {p}")
return SentencePieceVocab(p)
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that.
if path.is_dir():
path2 = path / "tokenizer.model"
# Use `.parent` instead of /.. to handle the symlink case better.
path3 = path.parent / "tokenizer.model"
if path2.exists():
return load_spm(path2)
elif path3.exists():
return load_spm(path3)
raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir"
)
def main():
parser = argparse.ArgumentParser("convert_grok")
parser.add_argument("-i", "--input", type=str)
parser.add_argument("-o", "--save_path", type=pathlib.Path)
parser.add_argument(
"-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
)
parser.add_argument("--vocab_dir", type=str, default="")
parser.add_argument("--experts", type=str, default="")
args = parser.parse_args()
vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
convert_grok(args, vocab, ggml_type)
if __name__ == "__main__":
main()