Merge branch 'master' into fix-tokenizer

Georgi Gerganov 2023-08-27 00:42:05 +03:00
commit c767746399
6 changed files with 221 additions and 115 deletions

convert.py

@@ -3,6 +3,7 @@
 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
@@ -17,13 +18,14 @@ import re
 import signal
 import struct
 import sys
+import time
 import zipfile
 import numpy as np
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor # type: ignore

 if TYPE_CHECKING:
@@ -37,30 +39,70 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 ARCH=gguf.MODEL_ARCH.LLAMA
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]

+DEFAULT_CONCURRENCY = 8
+
 #
 # data types
 #
+
 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]

-DT_F16 = UnquantizedDataType('F16')
-DT_F32 = UnquantizedDataType('F32')
-DT_I32 = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize

-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass

-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16:  np.dtype(np.float16),
-    DT_F32:  np.dtype(np.float32),
-    DT_I32:  np.dtype(np.int32),
-}
+DT_F16  = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])

+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType
+
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+    dtype = np.dtype(np.float32), valid_conversions = [],
+    ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+    quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))

-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt

 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
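Note on the new Q8_0 layout (an editor's aside, not part of the diff): each block of 32 float32 values is stored as one record of the structured dtype above, a float16 scale d plus 32 int8 quants qs, so elements_to_bytes reports 34 bytes per 32 elements. A minimal standalone sketch; the helper name and sample data below are invented for illustration:

import numpy as np

BLOCK_SIZE = 32
QUANT_DTYPE = np.dtype([('d', '<f2'), ('qs', 'i1', (BLOCK_SIZE,))])  # same layout as DT_Q8_0

def q8_0_quantize_block(block: np.ndarray) -> np.ndarray:
    # One block: scale is max(|x|)/127, quants are the rounded ratios.
    assert block.shape == (BLOCK_SIZE,) and block.dtype == np.float32
    d = np.float32(abs(block).max() / 127)
    qs = np.zeros(BLOCK_SIZE, dtype=np.int8) if d == 0 else (block / d).round().astype(np.int8)
    out = np.empty((), dtype=QUANT_DTYPE)
    out['d'] = d
    out['qs'] = qs
    return out

block = np.linspace(-1.0, 1.0, BLOCK_SIZE, dtype=np.float32)
rec = q8_0_quantize_block(block)
print(QUANT_DTYPE.itemsize)     # 34 bytes per 32-element block
print(rec['d'], rec['qs'][:4])  # scale and the first few quants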
@@ -73,20 +115,22 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
     AllF32     = 0
     MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors

     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32
+
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}

 #
 # hparams loading
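An illustrative check (an editor's aside; the tensor names and shapes are made up, and it assumes GGMLFileType and the data types above are in scope): type_for_tensor now only inspects the tensor's rank, so 1D tensors stay F32 while everything else follows the selected file type via the lookup table.

from types import SimpleNamespace

norm_1d   = SimpleNamespace(shape=[4096])        # stand-in for a 1D LazyTensor
weight_2d = SimpleNamespace(shape=[4096, 4096])  # stand-in for a 2D LazyTensor

ftype = GGMLFileType.MostlyQ8_0
print(ftype.type_for_tensor('output_norm.weight', norm_1d).name)   # F32
print(ftype.type_for_tensor('output.weight', weight_2d).name)      # Q8_0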
@@ -415,7 +459,7 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]

     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]

-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
 @dataclass
 class LazyTensor:
     _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+            (self.data_type, ret.data_type, self.description)
         return ret

     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ class LazyTensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')

     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')

 LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ class LazyUnpickler(pickle.Unpickler):
         info = self.zip_file.getinfo(filename)

         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')

-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure. If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory. Specifically, there is a max of one
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        # Not reached.
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
         while futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
             yield result

 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
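A hedged usage sketch of the reworked bounded_parallel_map (not from the commit; slow_square is a stand-in for loading or converting one tensor): with concurrency=4 at most four calls are in flight, and new work is submitted only as results are consumed, which is the backpressure the docstring describes.

from concurrent.futures import ThreadPoolExecutor
import time

def slow_square(x: int) -> int:
    time.sleep(0.01)  # simulate per-item work
    return x * x

# Results come back in submission order because futures are popped FIFO.
for result in bounded_parallel_map(slow_square, range(10), concurrency=4,
                                   max_workers=4, factory=ThreadPoolExecutor):
    print(result)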
@@ -804,12 +844,11 @@ class OutputFile:
         self.gguf.add_token_types(toktypes)

     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)

     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -835,7 +874,20 @@ class OutputFile:
         of.close()

     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)

         of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ class OutputFile:
         of.write_meta()
         of.write_tensor_info()

-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)

         of.close()
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0

     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             continue
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
         out[name_new] = lazy_tensor
     return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32:    "f32",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0:"q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype",     choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
     parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)

     if args.dump_single:
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]

     print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")

-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
     print(f"Wrote {outfile}")

examples/main/main.cpp

@@ -598,7 +598,12 @@ int main(int argc, char ** argv) {
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, alpha_frequency, alpha_presence);
                 if (!penalize_nl) {
-                    logits[llama_token_nl(ctx)] = nl_logit;
+                    for (size_t idx = 0; idx < candidates_p.size; idx++) {
+                        if (candidates_p.data[idx].id == llama_token_nl(ctx)) {
+                            candidates_p.data[idx].logit = nl_logit;
+                            break;
+                        }
+                    }
                 }

                 if (grammar != NULL) {

examples/perplexity/perplexity.cpp

@@ -351,6 +351,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    fprintf(stderr, "================================= is_spm = %d\n", is_spm);

     // This is needed as usual for LLaMA models
     const bool add_bos = is_spm;
@@ -406,6 +407,8 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;

     const int n_vocab = llama_n_vocab(ctx);

+    std::vector<std::vector<int>> ending_tokens(4);
+
     std::vector<float> tok_logits(n_vocab);

     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
@@ -413,11 +416,21 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
         size_t context_size = context_embd.size();

+        for (int i = 0; i < 4; ++i) {
+            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[i], add_bos);
+            for (int k = 0; k < int(context_size); ++k) {
+                if (ending_tokens[i][k] != context_embd[k]) {
+                    fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
+                    break;
+                }
+            }
+        }
+
         // Do the 1st ending
         // In this case we include the context when evaluating
-        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
+        //auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
+        auto query_embd = ending_tokens[0];
         auto query_size = query_embd.size();
+        //printf("First query: %d\n",(int)query_size);

         // Stop if query wont fit the ctx window
         if (query_size > (size_t)params.n_ctx) {
@@ -462,7 +475,8 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {

             // Tokenize the query
-            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_embd.resize(ending_tokens[ending_idx].size() - context_size);
+            std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
             query_size = query_embd.size();

             // Stop if query wont fit the ctx window

flake.lock generated

@@ -5,11 +5,11 @@
       "systems": "systems"
     },
     "locked": {
-      "lastModified": 1685518550,
-      "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
+      "lastModified": 1692799911,
+      "narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
       "owner": "numtide",
       "repo": "flake-utils",
-      "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
+      "rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
       "type": "github"
     },
     "original": {
@@ -20,11 +20,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1685931219,
-        "narHash": "sha256-8EWeOZ6LKQfgAjB/USffUSELPRjw88A+xTcXnOUvO5M=",
+        "lastModified": 1692913444,
+        "narHash": "sha256-1SvMQm2DwofNxXVtNWWtIcTh7GctEVrS/Xel/mdc6iY=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "7409480d5c8584a1a83c422530419efe4afb0d19",
+        "rev": "18324978d632ffc55ef1d928e81630c620f4f447",
         "type": "github"
       },
       "original": {

flake.nix

@@ -6,6 +6,9 @@
   outputs = { self, nixpkgs, flake-utils }:
     flake-utils.lib.eachDefaultSystem (system:
       let
+        name = "llama.cpp";
+        src = ./.;
+        meta.mainProgram = "llama";
         inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin;
         buildInputs = with pkgs; [ openmpi ];
         osSpecific = with pkgs; buildInputs ++
@@ -21,11 +24,17 @@
               CoreGraphics
               CoreVideo
             ]
+          else if isDarwin then
+            with pkgs.darwin.apple_sdk.frameworks; [
+              Accelerate
+              CoreGraphics
+              CoreVideo
+            ]
           else
             with pkgs; [ openblas ]
         );
         pkgs = import nixpkgs { inherit system; };
-        nativeBuildInputs = with pkgs; [ cmake pkgconfig ];
+        nativeBuildInputs = with pkgs; [ cmake ninja pkgconfig ];
         llama-python =
           pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
         postPatch = ''
@@ -38,35 +47,35 @@
           mv $out/bin/server $out/bin/llama-server
         '';
         cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ];
-      in {
+      in
+      {
         packages.default = pkgs.stdenv.mkDerivation {
-          name = "llama.cpp";
-          src = ./.;
-          postPatch = postPatch;
-          nativeBuildInputs = nativeBuildInputs;
-          buildInputs = osSpecific;
+          inherit name src meta postPatch nativeBuildInputs buildInputs postInstall;
           cmakeFlags = cmakeFlags
             ++ (if isAarch64 && isDarwin then [
               "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1"
               "-DLLAMA_METAL=ON"
             ] else [
               "-DLLAMA_BLAS=ON"
              "-DLLAMA_BLAS_VENDOR=OpenBLAS"
            ]);
-          postInstall = postInstall;
-          meta.mainProgram = "llama";
         };
         packages.opencl = pkgs.stdenv.mkDerivation {
-          name = "llama.cpp";
-          src = ./.;
-          postPatch = postPatch;
-          nativeBuildInputs = nativeBuildInputs;
+          inherit name src meta postPatch nativeBuildInputs postInstall;
           buildInputs = with pkgs; buildInputs ++ [ clblast ];
           cmakeFlags = cmakeFlags ++ [
             "-DLLAMA_CLBLAST=ON"
           ];
-          postInstall = postInstall;
-          meta.mainProgram = "llama";
+        };
+        packages.rocm = pkgs.stdenv.mkDerivation {
+          inherit name src meta postPatch nativeBuildInputs postInstall;
+          buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];
+          cmakeFlags = cmakeFlags ++ [
+            "-DLLAMA_HIPBLAS=1"
+            "-DCMAKE_C_COMPILER=hipcc"
+            "-DCMAKE_CXX_COMPILER=hipcc"
+            "-DCMAKE_POSITION_INDEPENDENT_CODE=ON"
+          ];
         };
         apps.llama-server = {
           type = "app";
@@ -80,8 +89,13 @@
           type = "app";
           program = "${self.packages.${system}.default}/bin/llama";
         };
+        apps.quantize = {
+          type = "app";
+          program = "${self.packages.${system}.default}/bin/quantize";
+        };
         apps.default = self.apps.${system}.llama;
         devShells.default = pkgs.mkShell {
+          buildInputs = [ llama-python ];
           packages = nativeBuildInputs ++ osSpecific;
         };
       });

llama.cpp

@@ -1,9 +1,6 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
-#include <cstddef>
-#include <cstdint>
-#include <cstdio>
 #endif

 #include "llama.h"
@@ -62,6 +59,9 @@
 #include <cinttypes>
 #include <climits>
 #include <cstdarg>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
 #include <cstring>
 #include <ctime>
 #include <fstream>
@@ -955,10 +955,10 @@ struct llama_vocab {
     id linefeed_id = 13;

     int find_bpe_rank(std::string token_left, std::string token_right) const {
-        replace_all(token_left,  " ",  "Ġ");
-        replace_all(token_left,  "\n", "Ċ");
-        replace_all(token_right, " ",  "Ġ");
-        replace_all(token_right, "\n", "Ċ");
+        replace_all(token_left,  " ",  "\u0120");
+        replace_all(token_left,  "\n", "\u010A");
+        replace_all(token_right, " ",  "\u0120");
+        replace_all(token_right, "\n", "\u010A");

         auto it = bpe_ranks.find(std::make_pair(token_left, token_right));

         if (it == bpe_ranks.end()) {
@@ -3894,7 +3894,7 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array *
     // Calculate absolute value of second derivatives
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = abs(second_derivatives[i]);
+        second_derivatives[i] = std::abs(second_derivatives[i]);
     }

     // Normalize the second derivatives
@@ -4660,6 +4660,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));

+    llama_model model;
+    llm_load_arch(*ml, model);
+    llm_load_hparams(*ml, model, 0, 0, 0);
+
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     struct gguf_context * ctx_out = gguf_init_empty();
@@ -4685,6 +4689,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             ++n_feed_forward_w2;
         }
     }
+    if (n_attention_wv != n_feed_forward_w2 || (uint32_t)n_attention_wv != model.hparams.n_layer) {
+        LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
+                __func__, n_attention_wv, n_feed_forward_w2, model.hparams.n_layer);
+    }

     int i_attention_wv = 0;
     int i_feed_forward_w2 = 0;
@@ -4761,8 +4769,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
             int nx = tensor->ne[0];
-            int ny = tensor->ne[1];
-            if (nx % QK_K == 0 && ny % QK_K == 0) {
+            if (nx % QK_K == 0) {
                 new_type = GGML_TYPE_Q6_K;
             }
         } else if (name.find("attn_v.weight") != std::string::npos) {
@@ -4776,6 +4783,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
             else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
                     (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
+            if (model.type == MODEL_70B) {
+                // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
+                // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
+                // nearly negligible increase in model size by quantizing this tensor with more bits:
+                if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
+            }
             ++i_attention_wv;
         } else if (name.find("ffn_down.weight") != std::string::npos) {
             if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
@@ -4805,8 +4818,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
             int nx = tensor->ne[0];
             int ny = tensor->ne[1];
-            if (nx % QK_K != 0 || ny % QK_K != 0) {
-                LLAMA_LOG_INFO("\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K);
+            if (nx % QK_K != 0) {
+                LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K);
                 convert_incompatible_tensor = true;
             }
         }