Merge branch 'master' into fix-tokenizer
This commit is contained in:
commit
c767746399
6 changed files with 221 additions and 115 deletions
206
convert.py
206
convert.py
|
@ -3,6 +3,7 @@
|
|||
import gguf
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
import copy
|
||||
import enum
|
||||
import faulthandler
|
||||
|
@ -17,13 +18,14 @@ import re
|
|||
import signal
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
import numpy as np
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
|
||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
|
||||
from sentencepiece import SentencePieceProcessor # type: ignore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -37,30 +39,70 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
|
|||
ARCH=gguf.MODEL_ARCH.LLAMA
|
||||
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
|
||||
|
||||
DEFAULT_CONCURRENCY = 8
|
||||
#
|
||||
# data types
|
||||
#
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnquantizedDataType:
|
||||
class DataType:
|
||||
name: str
|
||||
dtype: 'np.dtype[Any]'
|
||||
valid_conversions: List[str]
|
||||
|
||||
DT_F16 = UnquantizedDataType('F16')
|
||||
DT_F32 = UnquantizedDataType('F32')
|
||||
DT_I32 = UnquantizedDataType('I32')
|
||||
DT_BF16 = UnquantizedDataType('BF16')
|
||||
def elements_to_bytes(self, n_elements: int) -> int:
|
||||
return n_elements * self.dtype.itemsize
|
||||
|
||||
DataType = Union[UnquantizedDataType]
|
||||
@dataclass(frozen=True)
|
||||
class UnquantizedDataType(DataType):
|
||||
pass
|
||||
|
||||
DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
|
||||
DT_BF16: np.dtype(np.uint16),
|
||||
DT_F16: np.dtype(np.float16),
|
||||
DT_F32: np.dtype(np.float32),
|
||||
DT_I32: np.dtype(np.int32),
|
||||
}
|
||||
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
|
||||
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
|
||||
# I32 tensors are backed by 32-bit integers. The numpy dtype must be int32 (not int16)
# because byte sizes are derived from dtype.itemsize in DataType.elements_to_bytes.
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int32), valid_conversions = [])
|
||||
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
|
||||
|
||||
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
|
||||
{dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
|
||||
@dataclass(frozen=True)
|
||||
class QuantizedDataType(DataType):
|
||||
block_size: int
|
||||
quantized_dtype: 'np.dtype[Any]'
|
||||
ggml_type: gguf.GGMLQuantizationType
|
||||
|
||||
def quantize(self, arr: NDArray) -> NDArray:
|
||||
raise NotImplementedError(f'Quantization for {self.name} not implemented')
|
||||
|
||||
def elements_to_bytes(self, n_elements: int) -> int:
|
||||
assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
|
||||
return self.quantized_dtype.itemsize * (n_elements // self.block_size)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Q8_0QuantizedDataType(QuantizedDataType):
|
||||
# Mini Q8_0 quantization in Python!
|
||||
def quantize(self, arr: NDArray) -> NDArray:
|
||||
assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
|
||||
assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
|
||||
n_blocks = arr.size // self.block_size
|
||||
blocks = arr.reshape((n_blocks, self.block_size))
|
||||
# Much faster implementation of block quantization contributed by @Cebtenzzre
|
||||
def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
|
||||
d = abs(blocks).max(axis = 1) / np.float32(127)
|
||||
with np.errstate(divide = 'ignore'):
|
||||
qs = (blocks / d[:, None]).round()
|
||||
qs[d == 0] = 0
|
||||
yield from zip(d, qs)
|
||||
return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
|
||||
|
||||
DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
|
||||
dtype = np.dtype(np.float32), valid_conversions = [],
|
||||
ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
|
||||
quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
|
||||
|
||||
# Quantized types skipped here because they may also map to np.float32
|
||||
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
|
||||
for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
|
||||
if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
|
||||
raise ValueError(f'Invalid duplicate data type {dt}')
|
||||
NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
|
||||
|
||||
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
||||
'BF16': DT_BF16,
|
||||
|
@ -73,20 +115,22 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
|
|||
# TODO: rename to LLAMAFileType
|
||||
# TODO: move to `gguf.py`
|
||||
class GGMLFileType(enum.IntEnum):
|
||||
AllF32 = 0
|
||||
MostlyF16 = 1 # except 1d tensors
|
||||
AllF32 = 0
|
||||
MostlyF16 = 1 # except 1d tensors
|
||||
MostlyQ8_0 = 7 # except 1d tensors
|
||||
|
||||
def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
|
||||
if len(tensor.shape) == 1:
|
||||
# 1D tensors are always F32.
|
||||
return DT_F32
|
||||
elif self == GGMLFileType.AllF32:
|
||||
return DT_F32
|
||||
elif self == GGMLFileType.MostlyF16:
|
||||
return DT_F16
|
||||
else:
|
||||
dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
|
||||
if dt is None:
|
||||
raise ValueError(self)
|
||||
# 1D tensors are always F32.
|
||||
return dt if len(tensor.shape) > 1 else DT_F32
|
||||
|
||||
GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
|
||||
GGMLFileType.AllF32 : DT_F32,
|
||||
GGMLFileType.MostlyF16 : DT_F16,
|
||||
GGMLFileType.MostlyQ8_0: DT_Q8_0,
|
||||
}
|
||||
|
||||
#
|
||||
# hparams loading
|
||||
|
@ -415,7 +459,7 @@ class UnquantizedTensor(Tensor):
|
|||
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
|
||||
|
||||
def astype(self, data_type: DataType) -> Tensor:
|
||||
dtype = DATA_TYPE_TO_NUMPY[data_type]
|
||||
dtype = data_type.dtype
|
||||
if self.data_type == DT_BF16:
|
||||
self.ndarray = bf16_to_fp32(self.ndarray)
|
||||
return UnquantizedTensor(self.ndarray.astype(dtype))
|
||||
|
@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
|
|||
GGMLCompatibleTensor = Union[UnquantizedTensor]
|
||||
|
||||
|
||||
class DeferredPermutedTensor(Tensor):
|
||||
def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
|
||||
self.base = base
|
||||
self.n_head = n_head
|
||||
self.data_type = self.base.data_type
|
||||
|
||||
def astype(self, data_type: DataType) -> Tensor:
|
||||
return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
|
||||
|
||||
def to_ggml(self) -> GGMLCompatibleTensor:
|
||||
return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
|
||||
|
||||
def permute(self, n_head: int, n_head_kv: int) -> Tensor:
|
||||
raise Exception("shouldn't permute twice")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LazyTensor:
|
||||
_load: Callable[[], Tensor]
|
||||
|
@ -479,7 +507,9 @@ class LazyTensor:
|
|||
|
||||
def load(self) -> Tensor:
|
||||
ret = self._load()
|
||||
assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
|
||||
# Should be okay if it maps to the same numpy type?
|
||||
assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
|
||||
(self.data_type, ret.data_type, self.description)
|
||||
return ret
|
||||
|
||||
def astype(self, data_type: DataType) -> 'LazyTensor':
|
||||
|
@ -490,8 +520,8 @@ class LazyTensor:
|
|||
return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
|
||||
|
||||
def validate_conversion_to(self, data_type: DataType) -> None:
|
||||
if data_type == self.data_type:
|
||||
return
|
||||
if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
|
||||
raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
|
||||
|
||||
|
||||
LazyModel = Dict[str, LazyTensor]
|
||||
|
@ -617,9 +647,7 @@ class LazyUnpickler(pickle.Unpickler):
|
|||
info = self.zip_file.getinfo(filename)
|
||||
|
||||
def load(offset: int, elm_count: int) -> NDArray:
|
||||
dtype = DATA_TYPE_TO_NUMPY.get(data_type)
|
||||
if dtype is None:
|
||||
raise Exception("tensor stored in unsupported format")
|
||||
dtype = data_type.dtype
|
||||
fp = self.zip_file.open(info)
|
||||
fp.seek(offset * dtype.itemsize)
|
||||
size = elm_count * dtype.itemsize
|
||||
|
@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
|
|||
|
||||
def convert(info: Dict[str, Any]) -> LazyTensor:
|
||||
data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
|
||||
numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
|
||||
numpy_dtype = data_type.dtype
|
||||
shape: List[int] = info['shape']
|
||||
begin, end = info['data_offsets']
|
||||
assert 0 <= begin <= end <= len(byte_buf)
|
||||
|
@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
|
|||
In = TypeVar('In')
Out = TypeVar('Out')

def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
    '''Parallel map, but with backpressure.  If the caller doesn't call `next`
    fast enough, this will stop calling `func` at some point rather than
    letting results pile up in memory.  Specifically, there is a max of one
    output value buffered per thread.

    Results are yielded in the same order as the inputs.

    func:        callable applied to every item of `iterable`.
    iterable:    input items; consumed lazily, one item per submitted task.
    concurrency: maximum number of tasks in flight at once.  Values below 2
                 run `func` serially in the calling thread, with no executor.
    max_workers: forwarded to `factory` as its `max_workers` argument.
    factory:     executor class (or compatible callable) used as a context
                 manager, e.g. ThreadPoolExecutor or ProcessPoolExecutor.
    '''
    if concurrency < 2:
        yield from map(func, iterable)
        # Stop here: falling through would re-iterate a re-iterable input
        # (list, dict view, ...) and yield every result a second time.
        return
    iterable = iter(iterable)
    with factory(max_workers = max_workers) as executor:
        futures: List[concurrent.futures.Future[Out]] = []
        done = False
        # Prime the pipeline with up to `concurrency` tasks.
        for _ in range(concurrency):
            try:
                futures.append(executor.submit(func, next(iterable)))
            except StopIteration:
                done = True
                break

        while futures:
            # Wait for the oldest task first so that output order matches input order.
            result = futures.pop(0).result()
            # Refill the slot(s) freed above before handing the result to the caller.
            while not done and len(futures) < concurrency:
                try:
                    futures.append(executor.submit(func, next(iterable)))
                except StopIteration:
                    done = True
                    break
            yield result
|
||||
|
||||
|
||||
def check_vocab_size(params: Params, vocab: Vocab) -> None:
|
||||
if params.n_vocab != vocab.vocab_size:
|
||||
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
|
||||
|
@ -804,12 +844,11 @@ class OutputFile:
|
|||
self.gguf.add_token_types(toktypes)
|
||||
|
||||
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
    '''Register the metadata of one tensor with the underlying gguf writer.

    name:   tensor name to record in the output file.
    tensor: lazy tensor whose shape and data type describe the stored data.

    The byte size comes from the data type's own elements_to_bytes(), so
    block-quantized types report their packed size rather than
    n_elements * itemsize.
    '''
    n_elements = int(np.prod(tensor.shape))
    # Only quantized data types carry a ggml_type; for the others this stays None.
    raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
    # QuantizedDataType names its storage layout `quantized_dtype`.  Compare against
    # None explicitly instead of using `or`, so the choice does not depend on the
    # truthiness of a numpy dtype object.
    data_type = getattr(tensor.data_type, 'quantized_dtype', None)
    if data_type is None:
        data_type = tensor.data_type.dtype
    data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
    self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
|
||||
|
||||
def write_meta(self) -> None:
|
||||
self.gguf.write_header_to_file()
|
||||
|
@ -835,7 +874,20 @@ class OutputFile:
|
|||
of.close()
|
||||
|
||||
@staticmethod
|
||||
def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
|
||||
def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
|
||||
name, lazy_tensor = item
|
||||
tensor = lazy_tensor.load().to_ggml()
|
||||
return (lazy_tensor.data_type, tensor.ndarray)
|
||||
|
||||
@staticmethod
|
||||
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
|
||||
dt, arr = item
|
||||
if not isinstance(dt, QuantizedDataType):
|
||||
return arr
|
||||
return dt.quantize(arr)
|
||||
|
||||
@staticmethod
|
||||
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
|
||||
check_vocab_size(params, vocab)
|
||||
|
||||
of = OutputFile(fname_out)
|
||||
|
@ -851,16 +903,19 @@ class OutputFile:
|
|||
of.write_meta()
|
||||
of.write_tensor_info()
|
||||
|
||||
def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
|
||||
name, lazy_tensor = item
|
||||
return lazy_tensor.load().to_ggml().ndarray
|
||||
|
||||
# tensor data
|
||||
ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
|
||||
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
|
||||
if ftype == GGMLFileType.MostlyQ8_0:
|
||||
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
|
||||
else:
|
||||
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
||||
|
||||
start = time.time()
|
||||
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
|
||||
elapsed = time.time() - start
|
||||
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
|
||||
padi = len(str(len(model)))
|
||||
print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
|
||||
print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
|
||||
of.gguf.write_tensor_data(ndarray)
|
||||
|
||||
of.close()
|
||||
|
@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
|
|||
return GGMLFileType.AllF32
|
||||
if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
|
||||
return GGMLFileType.MostlyF16
|
||||
if output_type_str == "q8_0":
|
||||
return GGMLFileType.MostlyQ8_0
|
||||
|
||||
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
|
||||
|
||||
|
@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
|
|||
print(f"skipping tensor {name_new}")
|
||||
continue
|
||||
else:
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
|
||||
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
|
||||
out[name_new] = lazy_tensor
|
||||
|
||||
return out
|
||||
|
@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
|
|||
namestr = {
|
||||
GGMLFileType.AllF32: "f32",
|
||||
GGMLFileType.MostlyF16: "f16",
|
||||
GGMLFileType.MostlyQ8_0:"q8_0",
|
||||
}[file_type]
|
||||
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
|
||||
if ret in model_paths:
|
||||
|
@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
|
||||
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
|
||||
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
|
||||
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
|
||||
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
|
||||
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
|
||||
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
|
||||
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
|
||||
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
|
||||
args = parser.parse_args(args_in)
|
||||
|
||||
if args.dump_single:
|
||||
|
@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
params.ftype = {
|
||||
"f32": GGMLFileType.AllF32,
|
||||
"f16": GGMLFileType.MostlyF16,
|
||||
"q8_0": GGMLFileType.MostlyQ8_0,
|
||||
}[args.outtype]
|
||||
|
||||
print(f"params = {params}")
|
||||
|
@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
|
|||
params.ftype = ftype
|
||||
print(f"Writing {outfile}, format {ftype}")
|
||||
|
||||
OutputFile.write_all(outfile, params, model, vocab)
|
||||
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
|
||||
print(f"Wrote {outfile}")
|
||||
|
||||
|
||||
|
|
|
@ -598,7 +598,12 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||
last_n_repeat, alpha_frequency, alpha_presence);
|
||||
if (!penalize_nl) {
|
||||
logits[llama_token_nl(ctx)] = nl_logit;
|
||||
for (size_t idx = 0; idx < candidates_p.size; idx++) {
|
||||
if (candidates_p.data[idx].id == llama_token_nl(ctx)) {
|
||||
candidates_p.data[idx].logit = nl_logit;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (grammar != NULL) {
|
||||
|
|
|
@ -351,6 +351,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
|
||||
|
||||
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
|
||||
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
|
||||
|
||||
// This is needed as usual for LLaMA models
|
||||
const bool add_bos = is_spm;
|
||||
|
@ -406,6 +407,8 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
double acc = 0.0f;
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
std::vector<std::vector<int>> ending_tokens(4);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
|
||||
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
|
||||
|
@ -413,11 +416,21 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
|
||||
size_t context_size = context_embd.size();
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[i], add_bos);
|
||||
for (int k = 0; k < int(context_size); ++k) {
|
||||
if (ending_tokens[i][k] != context_embd[k]) {
|
||||
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do the 1st ending
|
||||
// In this case we include the context when evaluating
|
||||
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
|
||||
//auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
|
||||
auto query_embd = ending_tokens[0];
|
||||
auto query_size = query_embd.size();
|
||||
//printf("First query: %d\n",(int)query_size);
|
||||
|
||||
// Stop if query wont fit the ctx window
|
||||
if (query_size > (size_t)params.n_ctx) {
|
||||
|
@ -462,7 +475,8 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
|
||||
|
||||
// Tokenize the query
|
||||
query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
|
||||
query_embd.resize(ending_tokens[ending_idx].size() - context_size);
|
||||
std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
|
||||
query_size = query_embd.size();
|
||||
|
||||
// Stop if query wont fit the ctx window
|
||||
|
|
12
flake.lock
generated
12
flake.lock
generated
|
@ -5,11 +5,11 @@
|
|||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1685518550,
|
||||
"narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
|
||||
"lastModified": 1692799911,
|
||||
"narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
|
||||
"rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1685931219,
|
||||
"narHash": "sha256-8EWeOZ6LKQfgAjB/USffUSELPRjw88A+xTcXnOUvO5M=",
|
||||
"lastModified": 1692913444,
|
||||
"narHash": "sha256-1SvMQm2DwofNxXVtNWWtIcTh7GctEVrS/Xel/mdc6iY=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "7409480d5c8584a1a83c422530419efe4afb0d19",
|
||||
"rev": "18324978d632ffc55ef1d928e81630c620f4f447",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
54
flake.nix
54
flake.nix
|
@ -6,6 +6,9 @@
|
|||
outputs = { self, nixpkgs, flake-utils }:
|
||||
flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
name = "llama.cpp";
|
||||
src = ./.;
|
||||
meta.mainProgram = "llama";
|
||||
inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
|
||||
buildInputs = with pkgs; [ openmpi ];
|
||||
osSpecific = with pkgs; buildInputs ++
|
||||
|
@ -21,11 +24,17 @@
|
|||
CoreGraphics
|
||||
CoreVideo
|
||||
]
|
||||
else if isDarwin then
|
||||
with pkgs.darwin.apple_sdk.frameworks; [
|
||||
Accelerate
|
||||
CoreGraphics
|
||||
CoreVideo
|
||||
]
|
||||
else
|
||||
with pkgs; [ openblas ]
|
||||
);
|
||||
pkgs = import nixpkgs { inherit system; };
|
||||
nativeBuildInputs = with pkgs; [ cmake pkgconfig ];
|
||||
nativeBuildInputs = with pkgs; [ cmake ninja pkgconfig ];
|
||||
llama-python =
|
||||
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
|
||||
postPatch = ''
|
||||
|
@ -38,35 +47,35 @@
|
|||
mv $out/bin/server $out/bin/llama-server
|
||||
'';
|
||||
cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
|
||||
in {
|
||||
in
|
||||
{
|
||||
packages.default = pkgs.stdenv.mkDerivation {
|
||||
name = "llama.cpp";
|
||||
src = ./.;
|
||||
postPatch = postPatch;
|
||||
nativeBuildInputs = nativeBuildInputs;
|
||||
buildInputs = osSpecific;
|
||||
inherit name src meta postPatch nativeBuildInputs buildInputs postInstall;
|
||||
cmakeFlags = cmakeFlags
|
||||
++ (if isAarch64 && isDarwin then [
|
||||
"-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
|
||||
"-DLLAMA_METAL=ON"
|
||||
] else [
|
||||
"-DLLAMA_BLAS=ON"
|
||||
"-DLLAMA_BLAS_VENDOR=OpenBLAS"
|
||||
"-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
|
||||
"-DLLAMA_METAL=ON"
|
||||
] else [
|
||||
"-DLLAMA_BLAS=ON"
|
||||
"-DLLAMA_BLAS_VENDOR=OpenBLAS"
|
||||
]);
|
||||
postInstall = postInstall;
|
||||
meta.mainProgram = "llama";
|
||||
};
|
||||
packages.opencl = pkgs.stdenv.mkDerivation {
|
||||
name = "llama.cpp";
|
||||
src = ./.;
|
||||
postPatch = postPatch;
|
||||
nativeBuildInputs = nativeBuildInputs;
|
||||
inherit name src meta postPatch nativeBuildInputs postInstall;
|
||||
buildInputs = with pkgs; buildInputs ++ [ clblast ];
|
||||
cmakeFlags = cmakeFlags ++ [
|
||||
"-DLLAMA_CLBLAST=ON"
|
||||
];
|
||||
postInstall = postInstall;
|
||||
meta.mainProgram = "llama";
|
||||
};
|
||||
packages.rocm = pkgs.stdenv.mkDerivation {
|
||||
inherit name src meta postPatch nativeBuildInputs postInstall;
|
||||
buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
|
||||
cmakeFlags = cmakeFlags ++ [
|
||||
"-DLLAMA_HIPBLAS=1"
|
||||
"-DCMAKE_C_COMPILER=hipcc"
|
||||
"-DCMAKE_CXX_COMPILER=hipcc"
|
||||
"-DCMAKE_POSITION_INDEPENDENT_CODE=ON"
|
||||
];
|
||||
};
|
||||
apps.llama-server = {
|
||||
type = "app";
|
||||
|
@ -80,8 +89,13 @@
|
|||
type = "app";
|
||||
program = "${self.packages.${system}.default}/bin/llama";
|
||||
};
|
||||
apps.quantize = {
|
||||
type = "app";
|
||||
program = "${self.packages.${system}.default}/bin/quantize";
|
||||
};
|
||||
apps.default = self.apps.${system}.llama;
|
||||
devShells.default = pkgs.mkShell {
|
||||
buildInputs = [ llama-python ];
|
||||
packages = nativeBuildInputs ++ osSpecific;
|
||||
};
|
||||
});
|
||||
|
|
37
llama.cpp
37
llama.cpp
|
@ -1,9 +1,6 @@
|
|||
// Defines fileno on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#endif
|
||||
|
||||
#include "llama.h"
|
||||
|
@ -62,6 +59,9 @@
|
|||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cstdarg>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
|
@ -955,10 +955,10 @@ struct llama_vocab {
|
|||
id linefeed_id = 13;
|
||||
|
||||
int find_bpe_rank(std::string token_left, std::string token_right) const {
|
||||
replace_all(token_left, " ", "Ġ");
|
||||
replace_all(token_left, "\n", "Ċ");
|
||||
replace_all(token_right, " ", "Ġ");
|
||||
replace_all(token_right, "\n", "Ċ");
|
||||
replace_all(token_left, " ", "\u0120");
|
||||
replace_all(token_left, "\n", "\u010A");
|
||||
replace_all(token_right, " ", "\u0120");
|
||||
replace_all(token_right, "\n", "\u010A");
|
||||
|
||||
auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
|
||||
if (it == bpe_ranks.end()) {
|
||||
|
@ -3894,7 +3894,7 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
|
|||
|
||||
// Calculate absolute value of second derivatives
|
||||
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||
second_derivatives[i] = abs(second_derivatives[i]);
|
||||
second_derivatives[i] = std::abs(second_derivatives[i]);
|
||||
}
|
||||
|
||||
// Normalize the second derivatives
|
||||
|
@ -4660,6 +4660,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
|
||||
std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
|
||||
|
||||
llama_model model;
|
||||
llm_load_arch(*ml, model);
|
||||
llm_load_hparams(*ml, model, 0, 0, 0);
|
||||
|
||||
const size_t align = GGUF_DEFAULT_ALIGNMENT;
|
||||
struct gguf_context * ctx_out = gguf_init_empty();
|
||||
|
||||
|
@ -4685,6 +4689,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
++n_feed_forward_w2;
|
||||
}
|
||||
}
|
||||
if (n_attention_wv != n_feed_forward_w2 || (uint32_t)n_attention_wv != model.hparams.n_layer) {
|
||||
LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
|
||||
__func__, n_attention_wv, n_feed_forward_w2, model.hparams.n_layer);
|
||||
}
|
||||
|
||||
int i_attention_wv = 0;
|
||||
int i_feed_forward_w2 = 0;
|
||||
|
@ -4761,8 +4769,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
|
||||
if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
|
||||
int nx = tensor->ne[0];
|
||||
int ny = tensor->ne[1];
|
||||
if (nx % QK_K == 0 && ny % QK_K == 0) {
|
||||
if (nx % QK_K == 0) {
|
||||
new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
} else if (name.find("attn_v.weight") != std::string::npos) {
|
||||
|
@ -4776,6 +4783,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
|
||||
else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
|
||||
(i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
|
||||
if (model.type == MODEL_70B) {
|
||||
// In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
|
||||
// 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
|
||||
// nearly negligible increase in model size by quantizing this tensor with more bits:
|
||||
if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
|
||||
}
|
||||
++i_attention_wv;
|
||||
} else if (name.find("ffn_down.weight") != std::string::npos) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
|
@ -4805,8 +4818,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
|
||||
int nx = tensor->ne[0];
|
||||
int ny = tensor->ne[1];
|
||||
if (nx % QK_K != 0 || ny % QK_K != 0) {
|
||||
LLAMA_LOG_INFO("\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K);
|
||||
if (nx % QK_K != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K);
|
||||
convert_incompatible_tensor = true;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue