convert-llama-ggmlv3-to-gguf: Try to handle files older than GGJTv3

Better error messages for files that cannot be converted

Add file type to GGUF output
This commit is contained in:
KerfuffleV2 2023-09-05 03:37:02 -06:00
parent bd33e5ab92
commit 00fe3fdb61

View file

@ -5,6 +5,7 @@ import argparse
import math import math
import struct import struct
import sys import sys
from enum import IntEnum
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@ -34,10 +35,34 @@ GGML_QUANT_SIZES = {
gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8), gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8),
} }
class GGMLFormat(IntEnum):
GGML = 0
GGMF = 1
GGJT = 2
class GGMLFType(IntEnum):
ALL_F32 = 0
MOSTLY_F16 = 1
MOSTLY_Q4_0 = 2
MOSTLY_Q4_1 = 3
MOSTLY_Q4_1_SOME_F16 = 4
MOSTLY_Q8_0 = 7
MOSTLY_Q5_0 = 8
MOSTLY_Q5_1 = 9
MOSTLY_Q2_K = 10
MOSTLY_Q3_K_S = 11
MOSTLY_Q3_K_M = 12
MOSTLY_Q3_K_L = 13
MOSTLY_Q4_K_S = 14
MOSTLY_Q4_K_M = 15
MOSTLY_Q5_K_S = 16
MOSTLY_Q5_K_M = 17
MOSTLY_Q6_K = 18
class Hyperparameters: class Hyperparameters:
def __init__(self): def __init__(self):
self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0 self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.n_ff = 0
self.n_ff = 0 self.ftype = GGMLFType.ALL_F32
def set_n_ff(self, model): def set_n_ff(self, model):
ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight')
@ -53,16 +78,21 @@ class Hyperparameters:
self.n_head, self.n_head,
self.n_layer, self.n_layer,
self.n_rot, self.n_rot,
self.ftype, ftype,
) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) ) = struct.unpack('<7I', data[offset:offset + (4 * 7)])
try:
self.ftype = GGMLFType(ftype)
except ValueError:
raise ValueError(f'Invalid ftype {ftype}')
return 4 * 7 return 4 * 7
def __str__(self): def __str__(self):
return f'<Hyperparameters: n_vocab={self.n_vocab}, n_embd={self.n_embd}, n_mult={self.n_mult}, n_head={self.n_head}, n_layer={self.n_layer}, n_rot={self.n_rot}, n_ff={self.n_ff}, ftype={self.ftype}>' return f'<Hyperparameters: n_vocab={self.n_vocab}, n_embd={self.n_embd}, n_mult={self.n_mult}, n_head={self.n_head}, n_layer={self.n_layer}, n_rot={self.n_rot}, n_ff={self.n_ff}, ftype={self.ftype.name}>'
class Vocab: class Vocab:
def __init__(self): def __init__(self, load_scores = True):
self.items = [] self.items = []
self.load_scores = load_scores
def load(self, data, offset, n_vocab): def load(self, data, offset, n_vocab):
orig_offset = offset orig_offset = offset
@ -70,20 +100,24 @@ class Vocab:
itemlen = struct.unpack('<I', data[offset:offset + 4])[0] itemlen = struct.unpack('<I', data[offset:offset + 4])[0]
assert itemlen < 4096, 'Absurd vocab item length' assert itemlen < 4096, 'Absurd vocab item length'
offset += 4 offset += 4
vocab = bytes(data[offset:offset + itemlen]) item_text = bytes(data[offset:offset + itemlen])
offset += itemlen offset += itemlen
score = struct.unpack('<f', data[offset:offset + 4])[0] if self.load_scores:
item_score = struct.unpack('<f', data[offset:offset + 4])[0]
offset += 4 offset += 4
self.items.append((vocab, score)) else:
item_score = 0.0
self.items.append((item_text, item_score))
return offset - orig_offset return offset - orig_offset
class Tensor: class Tensor:
def __init__(self): def __init__(self, use_padding = True):
self.name = None self.name = None
self.dims: tuple[int, ...] = () self.dims: tuple[int, ...] = ()
self.dtype = None self.dtype = None
self.start_offset = 0 self.start_offset = 0
self.len_bytes = np.int64(0) self.len_bytes = np.int64(0)
self.use_padding = use_padding
def load(self, data, offset): def load(self, data, offset):
orig_offset = offset orig_offset = offset
@ -99,7 +133,7 @@ class Tensor:
offset += 4 * n_dims offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len]) self.name = bytes(data[offset:offset + name_len])
offset += name_len offset += name_len
pad = ((offset + 31) & ~31) - offset pad = ((offset + 31) & ~31) - offset if self.use_padding else 0
offset += pad offset += pad
n_elems = np.prod(self.dims) n_elems = np.prod(self.dims)
n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize) n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize)
@ -109,7 +143,7 @@ class Tensor:
# print(n_dims, name_len, dtype, self.dims, self.name, pad) # print(n_dims, name_len, dtype, self.dims, self.name, pad)
return offset - orig_offset return offset - orig_offset
class GGMLV3Model: class GGMLModel:
def __init__(self): def __init__(self):
self.hyperparameters = None self.hyperparameters = None
self.vocab = None self.vocab = None
@ -117,20 +151,46 @@ class GGMLV3Model:
self.tensors = [] self.tensors = []
def validate_header(self, data, offset): def validate_header(self, data, offset):
if bytes(data[offset:offset + 4]) != b'tjgg' or struct.unpack('<I', data[offset + 4:offset + 8])[0] != 3: magic = bytes(data[offset:offset + 4])
raise ValueError('Only GGJTv3 supported') if magic == b'GGUF':
raise ValueError('File is already in GGUF format.')
if magic == b'lmgg':
self.file_format = GGMLFormat.GGML
self.format_version = 1
return 4
version = struct.unpack('<I', data[offset + 4:offset + 8])[0]
if magic == b'fmgg':
if version != 1:
raise ValueError(f'Cannot handle unexpected GGMF file version {version}')
self.file_format = GGMLFormat.GGMF
self.format_version = version
return 8 return 8
if magic == b'tjgg':
if version < 1 or version > 3:
raise ValueError(f'Cannot handle unexpected GGJT file version {version}')
self.file_format = GGMLFormat.GGJT
self.format_version = version
return 8
raise ValueError(f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file.")
def validate_conversion(self, ftype):
if (self.file_format < GGMLFormat.GGJT or self.format_version < 2) and ftype not in (GGMLFType.ALL_F32, GGMLFType.MOSTLY_F16):
raise ValueError(f'Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2. Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.')
if (self.file_format == GGMLFormat.GGJT and self.format_version == 2) and ftype in (GGMLFType.MOSTLY_Q4_0, GGMLFType.MOSTLY_Q4_1, GGMLFType.MOSTLY_Q4_1_SOME_F16, GGMLFType.MOSTLY_Q8_0):
raise ValueError(f'Q4 and Q8 quantizations changed in GGJTv3. Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.')
def load(self, data, offset): def load(self, data, offset):
offset += self.validate_header(data, offset) offset += self.validate_header(data, offset)
hp = Hyperparameters() hp = Hyperparameters()
offset += hp.load(data, offset) offset += hp.load(data, offset)
vocab = Vocab() print(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}')
self.validate_conversion(hp.ftype)
vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML)
offset += vocab.load(data, offset, hp.n_vocab) offset += vocab.load(data, offset, hp.n_vocab)
tensors: list[Tensor] = [] tensors: list[Tensor] = []
tensor_map = {} tensor_map = {}
while offset < len(data): while offset < len(data):
tensor = Tensor() tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF)
offset += tensor.load(data, offset) offset += tensor.load(data, offset)
tensor_map[tensor.name] = len(tensors) tensor_map[tensor.name] = len(tensors)
tensors.append(tensor) tensors.append(tensor)
@ -195,6 +255,7 @@ class GGMLToGGUF:
if name is not None: if name is not None:
gguf_writer.add_name(name) gguf_writer.add_name(name)
gguf_writer.add_description(desc) gguf_writer.add_description(desc)
gguf_writer.add_file_type(int(hp.ftype))
if self.params_override is not None: if self.params_override is not None:
po = self.params_override po = self.params_override
assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch' assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch'
@ -330,7 +391,7 @@ def main():
print(f'* Using config: {cfg}') print(f'* Using config: {cfg}')
print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n') print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n')
data = np.memmap(cfg.input, mode = 'r') data = np.memmap(cfg.input, mode = 'r')
model = GGMLV3Model() model = GGMLModel()
print('* Scanning GGML input file') print('* Scanning GGML input file')
offset = model.load(data, 0) offset = model.load(data, 0)
print(f'* GGML model hyperparameters: {model.hyperparameters}') print(f'* GGML model hyperparameters: {model.hyperparameters}')
@ -345,6 +406,8 @@ def main():
print(f'* Special vocab: {special_vocab}') print(f'* Special vocab: {special_vocab}')
else: else:
print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
if model.file_format == GGMLFormat.GGML:
print('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!')
converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab) converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab)
converter.save() converter.save()
print(f'* Successful completion. Output saved to: {cfg.output}') print(f'* Successful completion. Output saved to: {cfg.output}')