convert-llama-ggmlv3-to-gguf: Try to handle files older than GGJTv3

Better error messages for files that cannot be converted.
Add file type to GGUF output.
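For reference, the three pre-GGUF container formats are distinguished by their first four bytes: the ASCII tags 'ggml', 'ggmf' and 'ggjt' stored little-endian, which read back as b'lmgg', b'fmgg' and b'tjgg'. Bare GGML has no version field after the magic; GGMF and GGJT do. Below is a minimal standalone sketch of the detection logic this commit adds; the sniff_format helper is illustrative only and not part of the script, which implements the same idea in GGMLModel.validate_header().

import struct
from enum import IntEnum

class GGMLFormat(IntEnum):
    GGML = 0  # magic only; no version field
    GGMF = 1  # magic + version (must be 1)
    GGJT = 2  # magic + version (1 through 3)

def sniff_format(data: bytes) -> tuple[GGMLFormat, int]:
    # Illustrative helper mirroring GGMLModel.validate_header() in the diff below.
    magic = bytes(data[:4])
    if magic == b'lmgg':
        # Bare GGML: nothing follows the magic, treat as version 1.
        return (GGMLFormat.GGML, 1)
    version = struct.unpack('<I', data[4:8])[0]
    if magic == b'fmgg':
        return (GGMLFormat.GGMF, version)
    if magic == b'tjgg':
        return (GGMLFormat.GGJT, version)
    raise ValueError(f'Unexpected file magic {magic!r}')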
This commit is contained in:
parent bd33e5ab92
commit 00fe3fdb61

1 changed file with 81 additions and 18 deletions
--- a/convert-llama-ggmlv3-to-gguf.py
+++ b/convert-llama-ggmlv3-to-gguf.py
@@ -5,6 +5,7 @@ import argparse
 import math
 import struct
 import sys
+from enum import IntEnum
 from pathlib import Path
 
 import numpy as np
@@ -34,10 +35,34 @@ GGML_QUANT_SIZES = {
     gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8),
 }
 
+class GGMLFormat(IntEnum):
+    GGML = 0
+    GGMF = 1
+    GGJT = 2
+
+class GGMLFType(IntEnum):
+    ALL_F32 = 0
+    MOSTLY_F16 = 1
+    MOSTLY_Q4_0 = 2
+    MOSTLY_Q4_1 = 3
+    MOSTLY_Q4_1_SOME_F16 = 4
+    MOSTLY_Q8_0 = 7
+    MOSTLY_Q5_0 = 8
+    MOSTLY_Q5_1 = 9
+    MOSTLY_Q2_K = 10
+    MOSTLY_Q3_K_S = 11
+    MOSTLY_Q3_K_M = 12
+    MOSTLY_Q3_K_L = 13
+    MOSTLY_Q4_K_S = 14
+    MOSTLY_Q4_K_M = 15
+    MOSTLY_Q5_K_S = 16
+    MOSTLY_Q5_K_M = 17
+    MOSTLY_Q6_K = 18
+
 class Hyperparameters:
     def __init__(self):
-        self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0
-        self.n_ff = 0
+        self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.n_ff = 0
+        self.ftype = GGMLFType.ALL_F32
 
     def set_n_ff(self, model):
         ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight')
@@ -53,16 +78,21 @@ class Hyperparameters:
             self.n_head,
             self.n_layer,
             self.n_rot,
-            self.ftype,
+            ftype,
         ) = struct.unpack('<7I', data[offset:offset + (4 * 7)])
+        try:
+            self.ftype = GGMLFType(ftype)
+        except ValueError:
+            raise ValueError(f'Invalid ftype {ftype}')
         return 4 * 7
 
     def __str__(self):
-        return f'<Hyperparameters: n_vocab={self.n_vocab}, n_embd={self.n_embd}, n_mult={self.n_mult}, n_head={self.n_head}, n_layer={self.n_layer}, n_rot={self.n_rot}, n_ff={self.n_ff}, ftype={self.ftype}>'
+        return f'<Hyperparameters: n_vocab={self.n_vocab}, n_embd={self.n_embd}, n_mult={self.n_mult}, n_head={self.n_head}, n_layer={self.n_layer}, n_rot={self.n_rot}, n_ff={self.n_ff}, ftype={self.ftype.name}>'
 
 class Vocab:
-    def __init__(self):
+    def __init__(self, load_scores = True):
         self.items = []
+        self.load_scores = load_scores
 
     def load(self, data, offset, n_vocab):
         orig_offset = offset
@@ -70,20 +100,24 @@ class Vocab:
             itemlen = struct.unpack('<I', data[offset:offset + 4])[0]
             assert itemlen < 4096, 'Absurd vocab item length'
             offset += 4
-            vocab = bytes(data[offset:offset + itemlen])
+            item_text = bytes(data[offset:offset + itemlen])
             offset += itemlen
-            score = struct.unpack('<f', data[offset:offset + 4])[0]
-            offset += 4
-            self.items.append((vocab, score))
+            if self.load_scores:
+                item_score = struct.unpack('<f', data[offset:offset + 4])[0]
+                offset += 4
+            else:
+                item_score = 0.0
+            self.items.append((item_text, item_score))
         return offset - orig_offset
 
 class Tensor:
-    def __init__(self):
+    def __init__(self, use_padding = True):
         self.name = None
         self.dims: tuple[int, ...] = ()
         self.dtype = None
         self.start_offset = 0
         self.len_bytes = np.int64(0)
+        self.use_padding = use_padding
 
     def load(self, data, offset):
         orig_offset = offset
@@ -99,7 +133,7 @@ class Tensor:
         offset += 4 * n_dims
         self.name = bytes(data[offset:offset + name_len])
         offset += name_len
-        pad = ((offset + 31) & ~31) - offset
+        pad = ((offset + 31) & ~31) - offset if self.use_padding else 0
         offset += pad
         n_elems = np.prod(self.dims)
         n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize)
@@ -109,7 +143,7 @@ class Tensor:
         # print(n_dims, name_len, dtype, self.dims, self.name, pad)
         return offset - orig_offset
 
-class GGMLV3Model:
+class GGMLModel:
     def __init__(self):
         self.hyperparameters = None
         self.vocab = None
@@ -117,20 +151,46 @@ class GGMLV3Model:
         self.tensors = []
 
     def validate_header(self, data, offset):
-        if bytes(data[offset:offset + 4]) != b'tjgg' or struct.unpack('<I', data[offset + 4:offset + 8])[0] != 3:
-            raise ValueError('Only GGJTv3 supported')
-        return 8
+        magic = bytes(data[offset:offset + 4])
+        if magic == b'GGUF':
+            raise ValueError('File is already in GGUF format.')
+        if magic == b'lmgg':
+            self.file_format = GGMLFormat.GGML
+            self.format_version = 1
+            return 4
+        version = struct.unpack('<I', data[offset + 4:offset + 8])[0]
+        if magic == b'fmgg':
+            if version != 1:
+                raise ValueError(f'Cannot handle unexpected GGMF file version {version}')
+            self.file_format = GGMLFormat.GGMF
+            self.format_version = version
+            return 8
+        if magic == b'tjgg':
+            if version < 1 or version > 3:
+                raise ValueError(f'Cannot handle unexpected GGJT file version {version}')
+            self.file_format = GGMLFormat.GGJT
+            self.format_version = version
+            return 8
+        raise ValueError(f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file.")
+
+    def validate_conversion(self, ftype):
+        if (self.file_format < GGMLFormat.GGJT or self.format_version < 2) and ftype not in (GGMLFType.ALL_F32, GGMLFType.MOSTLY_F16):
+            raise ValueError(f'Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2. Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.')
+        if (self.file_format == GGMLFormat.GGJT and self.format_version == 2) and ftype in (GGMLFType.MOSTLY_Q4_0, GGMLFType.MOSTLY_Q4_1, GGMLFType.MOSTLY_Q4_1_SOME_F16, GGMLFType.MOSTLY_Q8_0):
            raise ValueError(f'Q4 and Q8 quantizations changed in GGJTv3. Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.')
 
     def load(self, data, offset):
         offset += self.validate_header(data, offset)
         hp = Hyperparameters()
         offset += hp.load(data, offset)
-        vocab = Vocab()
+        print(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}')
+        self.validate_conversion(hp.ftype)
+        vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML)
         offset += vocab.load(data, offset, hp.n_vocab)
         tensors: list[Tensor] = []
         tensor_map = {}
         while offset < len(data):
-            tensor = Tensor()
+            tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF)
             offset += tensor.load(data, offset)
             tensor_map[tensor.name] = len(tensors)
             tensors.append(tensor)
@@ -195,6 +255,7 @@ class GGMLToGGUF:
         if name is not None:
             gguf_writer.add_name(name)
         gguf_writer.add_description(desc)
+        gguf_writer.add_file_type(int(hp.ftype))
         if self.params_override is not None:
             po = self.params_override
             assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch'
@@ -330,7 +391,7 @@ def main():
     print(f'* Using config: {cfg}')
     print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n')
     data = np.memmap(cfg.input, mode = 'r')
-    model = GGMLV3Model()
+    model = GGMLModel()
     print('* Scanning GGML input file')
     offset = model.load(data, 0)
     print(f'* GGML model hyperparameters: {model.hyperparameters}')
@@ -345,6 +406,8 @@ def main():
         print(f'* Special vocab: {special_vocab}')
     else:
         print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
+    if model.file_format == GGMLFormat.GGML:
+        print('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!')
     converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab)
     converter.save()
     print(f'* Successful completion. Output saved to: {cfg.output}')
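Note on eligibility, as implemented by the new validate_conversion() above: GGML, GGMFv1 and GGJTv1 inputs convert only when the ftype is ALL_F32 or MOSTLY_F16, because the quantization formats changed in GGJTv2; GGJTv2 inputs are additionally rejected for MOSTLY_Q4_0, MOSTLY_Q4_1, MOSTLY_Q4_1_SOME_F16 and MOSTLY_Q8_0, since the Q4 and Q8 layouts changed again in GGJTv3; GGJTv3 inputs convert at any ftype.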