convert : various script cleanups/fixes + merges and special token handling (#2842)

* convert: Fix permute calls and method/func definitions

* Cleanups for gguf-py

* Minor type cleanups.

* Initial implementation of handling merges and special tokens

* convert: Handle special tokens and merges in vocab only mode

convert: Vocab only mode no longer requires loading model tensors

* gguf: Refactor tensor name mapping

* convert: Fix type hint for special_token_types in SpecialVocab

* Use common special vocab handling in various conversion scripts

* First pass at implementing suggested changes

* Second pass

* gguf: SpecialVocab: Fix issue with special token content not in a dict

gguf: SpecialVocab: Allow skipping handling of merges

* convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json

* convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer

* gguf: SpecialVocab: Actually set load_merges in object

* Uniform args parsing and vocab only mode for convert examples

* convert.py: Set gpt2 as tokenizer model when using BPE

* Squish last type warning in gguf.py - yay!
Kerfuffle 2023-08-30 02:25:50 -06:00 committed by GitHub
parent ad9ddcff6e
commit dc07dc492e
10 changed files with 728 additions and 748 deletions
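
For context, a minimal sketch (not part of the diff) of the flow a conversion script follows after this change, assuming the gguf-py package from this commit is importable as gguf; the model directory, architecture and output name are illustrative placeholders:

# Hedged sketch: how a convert script uses the new helpers.
from pathlib import Path
import gguf

model_dir = Path("models/my-hf-model")              # placeholder path
gguf_writer = gguf.GGUFWriter("out.gguf", "llama")

# ... the script adds hparams, the token list, scores and token types here ...

# New in this PR: merges plus bos/eos/unk/sep/pad ids are pulled from
# tokenizer.json / tokenizer_config.json / config.json by one helper, so the
# same code path serves full conversion and --vocab-only runs.
special_vocab = gguf.SpecialVocab(model_dir, load_merges=True)
special_vocab.add_to_gguf(gguf_writer)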

gguf-py/gguf/gguf.py

@@ -4,9 +4,13 @@ import sys
import struct
import tempfile
import numpy as np
import json
import os
from pathlib import Path
from enum import IntEnum, auto
from typing import Any, IO, List, Optional
from io import BufferedWriter
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
#
# constants
@@ -71,35 +75,35 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
GPT2 = auto()
GPTJ = auto()
GPTNEOX = auto()
MPT = auto()
LLAMA : int = auto()
FALCON : int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX: int = auto()
MPT : int = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
POS_EMBD = auto()
OUTPUT = auto()
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ATTN_Q = auto()
ATTN_K = auto()
ATTN_V = auto()
ATTN_QKV = auto()
ATTN_OUT = auto()
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
FFN_NORM = auto()
TOKEN_EMBD : int = auto()
POS_EMBD : int = auto()
OUTPUT : int = auto()
OUTPUT_NORM : int = auto()
ROPE_FREQS : int = auto()
ATTN_Q : int = auto()
ATTN_K : int = auto()
ATTN_V : int = auto()
ATTN_QKV : int = auto()
ATTN_OUT : int = auto()
ATTN_NORM : int = auto()
ATTN_NORM_2 : int = auto()
ATTN_ROT_EMBD: int = auto()
FFN_GATE : int = auto()
FFN_DOWN : int = auto()
FFN_UP : int = auto()
FFN_NORM : int = auto()
MODEL_ARCH_NAMES = {
MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.GPT2: "gpt2",
@@ -108,7 +112,7 @@ MODEL_ARCH_NAMES = {
MODEL_ARCH.MPT: "mpt",
}
MODEL_TENSOR_NAMES = {
MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
MODEL_ARCH.LLAMA: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@@ -154,7 +158,7 @@ MODEL_TENSOR_NAMES = {
}
# tensors that will not be serialized
MODEL_TENSOR_SKIP = {
MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
@@ -162,167 +166,198 @@ MODEL_TENSOR_SKIP = {
}
# TODO: the following helper functions should be removed
# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR)
# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions
# REMOVE
def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool:
for skip in MODEL_TENSOR_SKIP.get(arch, []):
for i in range(n_blocks):
if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
return True
class TensorNameMap:
mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 mpt
"transformer.word_embeddings", # falcon
"model.embed_tokens", # llama-hf
"tok_embeddings", # llama-pth
),
return False
# Position embeddings
MODEL_TENSOR.POS_EMBD: (
"transformer.wpe", # gpt2
),
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf
"output", # llama-pth
),
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict:
tensor_map = {}
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 falcon
"model.norm", # llama-hf
"norm", # llama-pth
),
# Token embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None)
# Rope frequencies
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
),
}
tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox
tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
tensor_map["transformer.word_embeddings"] = mapped_to # falcon
tensor_map["model.embed_tokens"] = mapped_to # llama-hf
tensor_map["tok_embeddings"] = mapped_to # llama-pth
# Position embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
tensor_map["transformer.wpe"] = mapped_to # gpt2
# Output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
tensor_map["embed_out"] = mapped_to # gptneox
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth
# Output norm
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
tensor_map["transformer.norm_f"] = mapped_to # mpt
tensor_map["model.norm"] = mapped_to # llama-hf
tensor_map["norm"] = mapped_to # llama-pth
# Rope frequencies
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
tensor_map["rope.freqs"] = mapped_to # llama-pth
# Attention and feed-forward blocks
for i in range(0, n_blocks):
block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
# Attention norm
# TODO: is there are simpler way to write these 2 lines in Python?
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to else None
tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b
tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b
tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth
),
# Attention norm 2
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn", # falcon40b
),
# Attention query-key-value
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.h.{bid}.self_attention.query_key_value", # falcon
),
# Attention query
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth
),
# Attention key
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth
),
# Attention value
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth
),
# Attention output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
),
# Rotary embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
),
# Feed-forward norm
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth
),
# Feed-forward up
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.c_fc", # gpt2
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"model.layers.{bid}.mlp.up_proj", # llama-hf
"layers.{bid}.feed_forward.w3", # llama-pth
),
# Feed-forward gate
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf
"layers.{bid}.feed_forward.w1", # llama-pth
),
# Feed-forward down
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None)
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"model.layers.{bid}.mlp.down_proj", # llama-hf
"layers.{bid}.feed_forward.w2", # llama-pth
),
}
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox
tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth
mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
return tensor_map
tensor_names: Dict[MODEL_TENSOR, str]
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
mapping = self.mapping = {}
tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
for tensor, keys in self.mappings_cfg.items():
tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
for key in keys:
mapping[key] = (tensor, tensor_name)
for bid in range(n_blocks):
for tensor, keys in self.block_mappings_cfg.items():
tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
tensor_name = tensor_name.format(bid = bid)
for key in keys:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
result = self.mapping.get(key)
if result is not None:
return result
for suffix in try_suffixes:
if key.endswith(suffix):
result = self.mapping.get(key[:-len(suffix)])
if result is not None:
return (result[0], result[1] + suffix)
return None
def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[0]
def __getitem__(self, key: str) -> str:
try:
return self.mapping[key][1]
except KeyError:
raise KeyError(key)
def __contains__(self, key: str) -> bool:
return key in self.mapping
def __repr__(self) -> str:
return repr(self.mapping)
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
return TensorNameMap(arch, n_blocks)
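
A short, hedged example of how a converter is expected to query the new TensorNameMap (assuming the gguf-py package is importable as gguf; the HF tensor name below is one of those listed in block_mappings_cfg, and the exact GGUF-side string comes from MODEL_TENSOR_NAMES):

import gguf

# Build the map for a hypothetical 32-block LLaMA model.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, 32)

hf_name = "model.layers.0.self_attn.q_proj.weight"
new_name = tmap.get_name(hf_name, try_suffixes=(".weight", ".bias"))
ttype    = tmap.get_type(hf_name, try_suffixes=(".weight", ".bias"))

# get_type_and_name strips a known suffix, looks up the base name and
# re-attaches the suffix, so new_name ends in ".weight" and ttype is
# MODEL_TENSOR.ATTN_Q; unknown names simply return None.
assert ttype is gguf.MODEL_TENSOR.ATTN_Q
assert new_name is not None and new_name.endswith(".weight")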
class TokenType(IntEnum):
NORMAL = 1
@@ -388,15 +423,21 @@ class GGUFValueType(IntEnum):
class GGUFWriter:
def __init__(self, path: str, arch: str, use_temp_file = True):
fout: BufferedWriter
arch: str
offset_tensor = 0
data_alignment = GGUF_DEFAULT_ALIGNMENT
kv_data = b""
kv_data_count = 0
ti_data = b""
ti_data_count = 0
use_temp_file: bool
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
tensors: List[Tuple[np.ndarray[Any, Any], int]]
def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
self.fout = open(path, "wb")
self.arch = arch
self.offset_tensor = 0
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.kv_data = b""
self.kv_data_count = 0
self.ti_data = b""
self.ti_data_count = 0
self.add_architecture()
self.use_temp_file = use_temp_file
self.tensors = []
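
For reference, a hedged sketch of the GGUFWriter lifecycle with the attributes now declared at class level; the write_header_to_file/write_kv_data_to_file/close calls are the sequence the convert scripts use and are assumed unchanged by this commit:

import numpy as np
import gguf

writer = gguf.GGUFWriter("example.gguf", "llama", use_temp_file=False)
writer.add_uint32("llama.context_length", 2048)                    # arbitrary KV pair
writer.add_tensor("output_norm.weight", np.ones(16, dtype=np.float32))

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()   # with use_temp_file=False the tensors come from self.tensors
writer.close()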
@@ -470,14 +511,27 @@ class GGUFWriter:
self.add_key(key)
self.add_val(val, GGUFValueType.STRING)
def add_array(self, key: str, val: list):
if not isinstance(val, list):
raise ValueError("Value must be a list for array type")
def add_array(self, key: str, val: Sequence[Any]):
if not isinstance(val, Sequence):
raise ValueError("Value must be a sequence for array type")
self.add_key(key)
self.add_val(val, GGUFValueType.ARRAY)
def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True):
_simple_value_packing = {
GGUFValueType.UINT8: "<B",
GGUFValueType.INT8: "<b",
GGUFValueType.UINT16: "<H",
GGUFValueType.INT16: "<h",
GGUFValueType.UINT32: "<I",
GGUFValueType.INT32: "<i",
GGUFValueType.FLOAT32: "<f",
GGUFValueType.UINT64: "<Q",
GGUFValueType.INT64: "<q",
GGUFValueType.FLOAT64: "<d",
GGUFValueType.BOOL: "?" ,
}
def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
if vtype is None:
vtype = GGUFValueType.get_type(val)
@@ -485,47 +539,29 @@ class GGUFWriter:
self.kv_data += struct.pack("<I", vtype)
self.kv_data_count += 1
if vtype == GGUFValueType.UINT8:
self.kv_data += struct.pack("<B", val)
elif vtype == GGUFValueType.INT8:
self.kv_data += struct.pack("<b", val)
elif vtype == GGUFValueType.UINT16:
self.kv_data += struct.pack("<H", val)
elif vtype == GGUFValueType.INT16:
self.kv_data += struct.pack("<h", val)
elif vtype == GGUFValueType.UINT32:
self.kv_data += struct.pack("<I", val)
elif vtype == GGUFValueType.INT32:
self.kv_data += struct.pack("<i", val)
elif vtype == GGUFValueType.FLOAT32:
self.kv_data += struct.pack("<f", val)
elif vtype == GGUFValueType.UINT64:
self.kv_data += struct.pack("<Q", val)
elif vtype == GGUFValueType.INT64:
self.kv_data += struct.pack("<q", val)
elif vtype == GGUFValueType.FLOAT64:
self.kv_data += struct.pack("<d", val)
elif vtype == GGUFValueType.BOOL:
self.kv_data += struct.pack("?", val)
pack_fmt = self._simple_value_packing.get(vtype)
if pack_fmt is not None:
self.kv_data += struct.pack(pack_fmt, val)
elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val
self.kv_data += struct.pack("<Q", len(encoded_val))
self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY:
ltype = set([GGUFValueType.get_type(item) for item in val])
assert len(ltype) == 1, "All items in a GGUF array should be of the same type"
self.kv_data += struct.pack("<I", list(ltype)[0])
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
ltype = GGUFValueType.get_type(val[0])
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
raise ValueError("All items in a GGUF array should be of the same type")
self.kv_data += struct.pack("<I", ltype)
self.kv_data += struct.pack("<Q", len(val))
for item in val:
self.add_val(item, add_vtype=False)
else:
raise ValueError("Invalid GGUF metadata value type")
raise ValueError("Invalid GGUF metadata value type or value")
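
The long if/elif chain above collapses into a single lookup in _simple_value_packing; a standalone sketch of the same table-driven packing idea (the names here are illustrative, not the class itself):

import struct
from enum import IntEnum, auto

class ValType(IntEnum):          # stand-in for GGUFValueType
    UINT32  = auto()
    FLOAT32 = auto()
    BOOL    = auto()

_packing = {ValType.UINT32: "<I", ValType.FLOAT32: "<f", ValType.BOOL: "?"}

def pack_simple(vtype: ValType, val) -> bytes:
    fmt = _packing.get(vtype)
    if fmt is None:
        raise ValueError(f"no simple packing for {vtype!r}")
    return struct.pack(fmt, val)

# 7 as a little-endian uint32, exactly what the old UINT32 branch produced.
assert pack_simple(ValType.UINT32, 7) == b"\x07\x00\x00\x00"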
@staticmethod
def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
encoded_name = name.encode("utf8")
@@ -544,16 +580,18 @@ class GGUFWriter:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
if self.use_temp_file and not hasattr(self, "temp_file"):
self.temp_file = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
self.temp_file.seek(0)
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
fp.seek(0)
self.temp_file = fp
self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if not self.use_temp_file:
if self.temp_file is None:
self.tensors.append((tensor, pad))
return
@@ -562,25 +600,22 @@ class GGUFWriter:
if pad != 0:
self.temp_file.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray):
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0:
self.fout.write(bytes([0] * pad))
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
self.write_padding(self.fout, self.fout.tell())
tensor.tofile(self.fout)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if pad != 0:
self.fout.write(bytes([0] * pad))
self.write_padding(self.fout, tensor.nbytes)
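
write_padding folds the alignment arithmetic that was previously repeated in write_tensor_data and write_tensors_to_file into one place; the math itself, shown standalone:

def ggml_pad(x: int, n: int) -> int:
    # round x up to the next multiple of n
    return ((x + n - 1) // n) * n

# With the default 32-byte alignment, a 100-byte tensor is followed by
# 28 zero bytes so the next tensor starts on a 32-byte boundary.
assert ggml_pad(100, 32) - 100 == 28
assert ggml_pad(128, 32) - 128 == 0   # already aligned: no padding written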
def write_tensors_to_file(self):
self.write_ti_data_to_file()
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
if pad != 0:
self.fout.write(bytes([0] * pad))
self.write_padding(self.fout, self.fout.tell())
if not self.use_temp_file:
if self.temp_file is None:
for (currtensor, currpad) in self.tensors:
currtensor.tofile(self.fout)
if currpad != 0:
@@ -654,10 +689,6 @@ class GGUFWriter:
self.add_bool(
KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
def add_tensor_data_layout(self, layout: str):
self.add_string(
KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_head_count(self, count: int):
self.add_uint32(
KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
@@ -695,16 +726,16 @@ class GGUFWriter:
def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model)
def add_token_list(self, tokens: List):
def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
self.add_array(KEY_TOKENIZER_LIST, tokens)
def add_token_merges(self, merges: List):
def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
self.add_array(KEY_TOKENIZER_MERGES, merges)
def add_token_types(self, types: List[int]):
def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
def add_token_scores(self, scores: List[float]):
def add_token_scores(self, scores: Sequence[float]):
self.add_array(KEY_TOKENIZER_SCORES, scores)
def add_bos_token_id(self, id: int):
@@ -723,6 +754,84 @@ class GGUFWriter:
self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
class SpecialVocab:
load_merges: bool = False
merges: List[str] = []
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
special_token_ids: Dict[str, int] = {}
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
self.special_token_ids = {}
self.load_merges = load_merges
if special_token_types is not None:
self.special_token_types = special_token_types
self.load(path)
def load(self, path: Path):
if not self.try_load_from_tokenizer_json(path):
self.try_load_from_config_json(path)
def try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json'
if not tokenizer_file.is_file():
return False
with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
tokenizer = json.load(f)
if self.load_merges:
merges = tokenizer.get('model', {}).get('merges')
if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
self.merges = merges
tokenizer_config_file = path / 'tokenizer_config.json'
added_tokens = tokenizer.get('added_tokens')
if added_tokens is None or not tokenizer_config_file.is_file():
return True
with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
for typ in self.special_token_types:
entry = tokenizer_config.get(f'{typ}_token')
if isinstance(entry, str):
tc_content = entry
elif isinstance(entry, dict):
entry_content = entry.get('content')
if not isinstance(entry_content, str):
continue
tc_content = entry_content
else:
continue
for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
if isinstance(maybe_token_id, int):
self.special_token_ids[typ] = maybe_token_id
break
return True
def try_load_from_config_json(self, path: Path) -> bool:
config_file = path / 'config.json'
if not config_file.is_file():
return False
with open(config_file, 'r', encoding = 'utf-8') as f:
config = json.load(f)
for typ in self.special_token_types:
maybe_token_id = config.get(f'{typ}_token_id')
if isinstance(maybe_token_id, int):
self.special_token_ids[typ] = maybe_token_id
return True
def add_to_gguf(self, gw: GGUFWriter):
if len(self.merges) > 0:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
continue
print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid)
def __repr__(self):
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
# Example usage:
if __name__ == "__main__":
# Example usage with a file
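
A hedged usage sketch for the new SpecialVocab class (the model directory is a placeholder; any HF checkout with a tokenizer.json or config.json behaves the same way):

from pathlib import Path
import gguf

sv = gguf.SpecialVocab(Path("models/my-hf-model"),
                       load_merges=True,                    # only useful for BPE tokenizers
                       special_token_types=('bos', 'eos'))  # override the default five types

# tokenizer.json is consulted first (merges + added_tokens), with config.json
# as the fallback for plain <typ>_token_id entries.
print(sv)                     # <SpecialVocab with N merges and special tokens {...}>
print(sv.special_token_ids)   # e.g. {'bos': 1, 'eos': 2} when found

# sv.add_to_gguf(writer) then calls writer.add_bos_token_id(1), writer.add_eos_token_id(2), ...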

gguf-py/gguf/py.typed (new empty file)

gguf-py/pyproject.toml

@@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
{include = "gguf"},
{include = "gguf/py.typed"},
]
readme = "README.md"
homepage = "https://ggml.ai"
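
Shipping gguf/py.typed with the package is what lets downstream type checkers see the annotations added throughout this commit (PEP 561). A hedged sketch of the effect, assuming gguf is installed from this package:

# check_gguf_types.py -- `mypy check_gguf_types.py` now type-checks this against
# gguf's own annotations because the installed package carries the py.typed marker.
import gguf

writer: gguf.GGUFWriter = gguf.GGUFWriter("typed.gguf", "llama")
writer.add_uint32("llama.context_length", 2048)   # checked against the (str, int) signature
writer.close()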