diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 0b6f6669b..8c6312508 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -10,6 +10,8 @@ on: pull_request: types: [opened, synchronize, reopened] paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/tests/**.*'] + schedule: + - cron: '00 0 * * *' jobs: server: @@ -70,14 +72,15 @@ jobs: run: | pip install -r examples/server/tests/requirements.txt - - name: Download models - id: download_models - run: | - cd examples/server/tests - ../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf - - name: Tests - id: server_integration_test + id: server_integration_tests run: | cd examples/server/tests PORT=8888 ./tests.sh + + - name: Slow tests + id: server_integration_tests_slow + if: github.event.schedule != '' + run: | + cd examples/server/tests + PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow diff --git a/README.md b/README.md index 67717c1e3..939646753 100644 --- a/README.md +++ b/README.md @@ -786,7 +786,7 @@ And after 4.45 hours, you will have the final perplexity. ### Interactive mode If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter. -In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt that makes LLaMa emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`. +In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt that makes LLaMA emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`. Here is an example of a few-shot interaction, invoked with the command @@ -850,7 +850,7 @@ Sample run: ``` == Running in interactive mode. == - Press Ctrl+C to interject at any time. - - Press Return to return control to LLaMa. + - Press Return to return control to LLaMA. - If you want to submit another line, end your input in '\'. Below is an instruction that describes a task. Write a response that appropriately completes the request. diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 28b92ac38..fa9d4f22f 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -8,9 +8,10 @@ import json import os import re import sys +from abc import ABC, abstractmethod from enum import IntEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast import numpy as np import torch @@ -35,8 +36,11 @@ class SentencePieceTokenTypes(IntEnum): UNUSED = 5 BYTE = 6 +AnyModel = TypeVar("AnyModel", bound="type[Model]") + +class Model(ABC): + _model_classes: dict[str, type[Model]] = {} -class Model: def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool): self.dir_model = dir_model self.ftype = ftype @@ -47,10 +51,14 @@ class Model: self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) - self.model_arch = self._get_model_architecture() self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + @property + @abstractmethod + def model_arch(self) -> gguf.MODEL_ARCH: + pass + def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any: key = next((k for k in keys if k in self.hparams), None) if key is not None: @@ -176,55 +184,21 @@ class Model: with open(dir_model / "config.json", "r", encoding="utf-8") as f: return json.load(f) - @staticmethod - def from_model_architecture(model_architecture): - if model_architecture == "GPTNeoXForCausalLM": - return GPTNeoXModel - if model_architecture == "BloomForCausalLM": - return BloomModel - if model_architecture == "MPTForCausalLM": - return MPTModel - if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): - return BaichuanModel - if model_architecture in ("FalconForCausalLM", "RWForCausalLM"): - return FalconModel - if model_architecture == "GPTBigCodeForCausalLM": - return StarCoderModel - if model_architecture == "GPTRefactForCausalLM": - return RefactModel - if model_architecture == "PersimmonForCausalLM": - return PersimmonModel - if model_architecture in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): - return StableLMModel - if model_architecture == "QWenLMHeadModel": - return QwenModel - if model_architecture == "Qwen2ForCausalLM": - return Model - if model_architecture == "MixtralForCausalLM": - return MixtralModel - if model_architecture == "GPT2LMHeadModel": - return GPT2Model - if model_architecture == "PhiForCausalLM": - return Phi2Model - if model_architecture == "PlamoForCausalLM": - return PlamoModel - if model_architecture == "CodeShellForCausalLM": - return CodeShellModel - if model_architecture == "OrionForCausalLM": - return OrionModel - if model_architecture == "InternLM2ForCausalLM": - return InternLM2Model - if model_architecture == "MiniCPMForCausalLM": - return MiniCPMModel - if model_architecture == "BertModel": - return BertModel - if model_architecture == "NomicBertModel": - return NomicBertModel - if model_architecture == "GemmaForCausalLM": - return GemmaModel - if model_architecture == "Starcoder2ForCausalLM": - return Model - return Model + @classmethod + def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: + assert names + def func(modelcls: type[Model]): + for name in names: + cls._model_classes[name] = modelcls + return modelcls + return func + + @classmethod + def from_model_architecture(cls, arch): + try: + return cls._model_classes[arch] + except KeyError: + raise NotImplementedError(f'Architecture {arch!r} not supported!') from None def _is_model_safetensors(self) -> bool: return Model.count_model_parts(self.dir_model, ".safetensors") > 0 @@ -239,57 +213,6 @@ class Model: return ("pytorch_model.bin",) return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) - def _get_model_architecture(self) -> gguf.MODEL_ARCH: - arch = self.hparams["architectures"][0] - if arch == "GPTNeoXForCausalLM": - return gguf.MODEL_ARCH.GPTNEOX - if arch == "BloomForCausalLM": - return gguf.MODEL_ARCH.BLOOM - if arch == "MPTForCausalLM": - return gguf.MODEL_ARCH.MPT - if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): - return gguf.MODEL_ARCH.BAICHUAN - if arch in ("FalconForCausalLM", "RWForCausalLM"): - return gguf.MODEL_ARCH.FALCON - if arch == "GPTBigCodeForCausalLM": - return gguf.MODEL_ARCH.STARCODER - if arch == "GPTRefactForCausalLM": - return gguf.MODEL_ARCH.REFACT - if arch == "PersimmonForCausalLM": - return gguf.MODEL_ARCH.PERSIMMON - if arch in ("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): - return gguf.MODEL_ARCH.STABLELM - if arch == "QWenLMHeadModel": - return gguf.MODEL_ARCH.QWEN - if arch == "Qwen2ForCausalLM": - return gguf.MODEL_ARCH.QWEN2 - if arch == "MixtralForCausalLM": - return gguf.MODEL_ARCH.LLAMA - if arch == "GPT2LMHeadModel": - return gguf.MODEL_ARCH.GPT2 - if arch == "PhiForCausalLM": - return gguf.MODEL_ARCH.PHI2 - if arch == "PlamoForCausalLM": - return gguf.MODEL_ARCH.PLAMO - if arch == "CodeShellForCausalLM": - return gguf.MODEL_ARCH.CODESHELL - if arch == "OrionForCausalLM": - return gguf.MODEL_ARCH.ORION - if arch == "InternLM2ForCausalLM": - return gguf.MODEL_ARCH.INTERNLM2 - if arch == "MiniCPMForCausalLM": - return gguf.MODEL_ARCH.MINICPM - if arch == "BertModel": - return gguf.MODEL_ARCH.BERT - if arch == "NomicBertModel": - return gguf.MODEL_ARCH.NOMIC_BERT - if arch == "GemmaForCausalLM": - return gguf.MODEL_ARCH.GEMMA - if arch == "Starcoder2ForCausalLM": - return gguf.MODEL_ARCH.STARCODER2 - - raise NotImplementedError(f'Architecture "{arch}" not supported!') - def _set_vocab_gpt2(self): dir_model = self.dir_model hparams = self.hparams @@ -457,7 +380,10 @@ class Model: special_vocab.add_to_gguf(self.gguf_writer) +@Model.register("GPTNeoXForCausalLM") class GPTNeoXModel(Model): + model_arch = gguf.MODEL_ARCH.GPTNEOX + def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] @@ -474,7 +400,10 @@ class GPTNeoXModel(Model): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) +@Model.register("BloomForCausalLM") class BloomModel(Model): + model_arch = gguf.MODEL_ARCH.BLOOM + def set_gguf_parameters(self): self.gguf_writer.add_name("Bloom") n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) @@ -566,7 +495,10 @@ class BloomModel(Model): print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") +@Model.register("MPTForCausalLM") class MPTModel(Model): + model_arch = gguf.MODEL_ARCH.MPT + def set_gguf_parameters(self): block_count = self.hparams["n_layers"] self.gguf_writer.add_name(self.dir_model.name) @@ -629,7 +561,10 @@ class MPTModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("OrionForCausalLM") class OrionModel(Model): + model_arch = gguf.MODEL_ARCH.ORION + def set_vocab(self): self._set_vocab_sentencepiece() @@ -708,7 +643,10 @@ class OrionModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM") class BaichuanModel(Model): + model_arch = gguf.MODEL_ARCH.BAICHUAN + def set_vocab(self): self._set_vocab_sentencepiece() @@ -823,7 +761,10 @@ class BaichuanModel(Model): return weights[r * n_part:r * n_part + r, ...] +@Model.register("FalconForCausalLM", "RWForCausalLM") class FalconModel(Model): + model_arch = gguf.MODEL_ARCH.FALCON + def set_gguf_parameters(self): block_count = self.hparams.get("num_hidden_layers") if block_count is None: @@ -916,7 +857,10 @@ class FalconModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("GPTBigCodeForCausalLM") class StarCoderModel(Model): + model_arch = gguf.MODEL_ARCH.STARCODER + def set_gguf_parameters(self): block_count = self.hparams["n_layer"] @@ -931,7 +875,10 @@ class StarCoderModel(Model): self.gguf_writer.add_file_type(self.ftype) +@Model.register("GPTRefactForCausalLM") class RefactModel(Model): + model_arch = gguf.MODEL_ARCH.REFACT + def set_gguf_parameters(self): hidden_dim = self.hparams["n_embd"] inner_dim = 4 * hidden_dim @@ -1015,7 +962,10 @@ class RefactModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("PersimmonForCausalLM") class PersimmonModel(Model): + model_arch = gguf.MODEL_ARCH.PERSIMMON + def set_gguf_parameters(self): block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) head_count = self.hparams["num_attention_heads"] @@ -1063,7 +1013,10 @@ class PersimmonModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") class StableLMModel(Model): + model_arch = gguf.MODEL_ARCH.STABLELM + def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() @@ -1087,12 +1040,18 @@ class StableLMModel(Model): self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) +@Model.register("MixtralForCausalLM") class MixtralModel(Model): + model_arch = gguf.MODEL_ARCH.LLAMA + def set_vocab(self): self._set_vocab_sentencepiece() +@Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): + model_arch = gguf.MODEL_ARCH.MINICPM + def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] self.gguf_writer.add_name("MiniCPM") @@ -1169,7 +1128,10 @@ class MiniCPMModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("QWenLMHeadModel") class QwenModel(Model): + model_arch = gguf.MODEL_ARCH.QWEN + @staticmethod def token_bytes_to_string(b): from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode @@ -1249,7 +1211,15 @@ class QwenModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("Qwen2ForCausalLM") +class Qwen2Model(Model): + model_arch = gguf.MODEL_ARCH.QWEN2 + + +@Model.register("GPT2LMHeadModel") class GPT2Model(Model): + model_arch = gguf.MODEL_ARCH.GPT2 + def set_gguf_parameters(self): self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_block_count(self.hparams["n_layer"]) @@ -1311,7 +1281,10 @@ class GPT2Model(Model): self.gguf_writer.add_tensor("output.weight", data) +@Model.register("PhiForCausalLM") class Phi2Model(Model): + model_arch = gguf.MODEL_ARCH.PHI2 + def set_gguf_parameters(self): block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) @@ -1333,7 +1306,10 @@ class Phi2Model(Model): self.gguf_writer.add_add_bos_token(False) +@Model.register("PlamoForCausalLM") class PlamoModel(Model): + model_arch = gguf.MODEL_ARCH.PLAMO + def set_vocab(self): self._set_vocab_sentencepiece() @@ -1412,7 +1388,10 @@ class PlamoModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("CodeShellForCausalLM") class CodeShellModel(Model): + model_arch = gguf.MODEL_ARCH.CODESHELL + def set_gguf_parameters(self): block_count = self.hparams["n_layer"] @@ -1477,7 +1456,10 @@ class CodeShellModel(Model): print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") +@Model.register("InternLM2ForCausalLM") class InternLM2Model(Model): + model_arch = gguf.MODEL_ARCH.INTERNLM2 + def set_vocab(self): # (TODO): Is there a better way? # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character @@ -1649,7 +1631,10 @@ in chat mode so that the conversation can end normally.") self.post_write_tensors(tensor_map, name, data_torch) +@Model.register("BertModel") class BertModel(Model): + model_arch = gguf.MODEL_ARCH.BERT + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.vocab_size = None @@ -1679,7 +1664,7 @@ class BertModel(Model): else: raise NotImplementedError("Only MEAN and CLS pooling types supported") - self.gguf_writer.add_pooling_type(pooling_type.value) + self.gguf_writer.add_pooling_type(pooling_type) def set_vocab(self): path = self.dir_model @@ -1755,7 +1740,10 @@ class BertModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("NomicBertModel") class NomicBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.NOMIC_BERT + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1792,7 +1780,10 @@ class NomicBertModel(BertModel): yield name, data +@Model.register("GemmaForCausalLM") class GemmaModel(Model): + model_arch = gguf.MODEL_ARCH.GEMMA + def set_vocab(self): self._set_vocab_sentencepiece() @@ -1848,6 +1839,11 @@ class GemmaModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("Starcoder2ForCausalLM") +class StarCoder2Model(Model): + model_arch = gguf.MODEL_ARCH.STARCODER2 + + ###### CONVERSION LOGIC ###### diff --git a/convert-llama-ggml-to-gguf.py b/convert-llama-ggml-to-gguf.py index b33108062..cd9644fcb 100755 --- a/convert-llama-ggml-to-gguf.py +++ b/convert-llama-ggml-to-gguf.py @@ -373,7 +373,7 @@ def handle_metadata(cfg, hp): raise ValueError('Unable to load metadata') vocab_path = Path(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir) vocab_factory = convert.VocabFactory(vocab_path) - vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype, cfg.model_metadata_dir) + vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype.split(","), cfg.model_metadata_dir) convert.check_vocab_size(params, vocab) return params, vocab, special_vocab @@ -398,8 +398,8 @@ def handle_args(): help ='Load HuggingFace/.pth vocab and metadata from the specified directory') parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") - parser.add_argument("--vocabtype", choices=["spm", "bpe"], default="spm", - help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)") + parser.add_argument("--vocabtype", default="spm,hfft", + help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)") return parser.parse_args() diff --git a/convert.py b/convert.py index 63a0a5d78..6e3a0319b 100755 --- a/convert.py +++ b/convert.py @@ -1282,35 +1282,32 @@ def load_some_model(path: Path) -> ModelPlus: class VocabFactory: + _FILES = {"spm": "tokenizer.model", "bpe": "vocab.json", "hfft": "tokenizer.json"} + def __init__(self, path: Path): self.path = path - self.files: dict[str, Path | None] = { - "tokenizer.model": None, - "vocab.json": None, - "tokenizer.json": None, - } - self._detect_files() + self.file_paths = self._detect_files() + print(f"Found vocab files: {self.file_paths}") - def _detect_files(self): - for file in self.files.keys(): - file_path = self.path / file - parent_file_path = self.path.parent / file - if file_path.exists(): - self.files[file] = file_path - elif parent_file_path.exists(): - self.files[file] = parent_file_path - print(f"Found vocab files: {self.files}") + def _detect_files(self) -> dict[str, Path | None]: + def locate(file: str) -> Path | None: + if (path := self.path / file).exists(): + return path + if (path := self.path.parent / file).exists(): + return path + return None - def _select_file(self, vocabtype: str | None) -> Path: - if vocabtype in ["spm", "bpe"]: - for file_key in self.files.keys(): - if (file := self.files[file_key]) is not None: - return file - raise FileNotFoundError(f"{vocabtype} vocab not found.") - if vocabtype == "hfft": - # For Hugging Face Fast Tokenizer, return the directory path instead of a specific file - return self.path - raise ValueError(f"Unsupported vocabulary type {vocabtype}") + return {vt: locate(f) for vt, f in self._FILES.items()} + + def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]: + for vtype in vocab_types: + try: + path = self.file_paths[vtype] + except KeyError: + raise ValueError(f"Unsupported vocabulary type {vtype}") from None + if path is not None: + return vtype, path + raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}") def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab: load_merges = vocabtype == "bpe" @@ -1322,30 +1319,30 @@ class VocabFactory: n_vocab=n_vocab, ) - def load_vocab(self, vocabtype: str, model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: - path = self._select_file(vocabtype) - print(f"Loading vocab file '{path}', type '{vocabtype}'") + def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + vocab_type, path = self._select_file(vocab_types) + print(f"Loading vocab file {path!r}, type {vocab_type!r}") added_tokens_path = path.parent / "added_tokens.json" vocab: Vocab - if vocabtype == "bpe": + if vocab_type == "bpe": vocab = BpeVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocabtype == "spm": + elif vocab_type == "spm": vocab = SentencePieceVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocabtype == "hfft": + elif vocab_type == "hfft": vocab = HfVocab( - path, added_tokens_path if added_tokens_path.exists() else None + path.parent, added_tokens_path if added_tokens_path.exists() else None ) else: - raise ValueError(f"Unsupported vocabulary type {vocabtype}") + raise ValueError(vocab_type) # FIXME: Respect --vocab-dir? special_vocab = self._create_special_vocab( vocab, - vocabtype, + vocab_type, model_parent_path, ) return vocab, special_vocab @@ -1379,15 +1376,14 @@ def main(args_in: list[str] | None = None) -> None: if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") - vocab_types = ["spm", "bpe", "hfft"] - parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") + parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None) parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--vocab-type", choices=vocab_types, help="The vocabulary format used to define the tokenizer model (default: spm)", default="spm") + parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") @@ -1448,7 +1444,7 @@ def main(args_in: list[str] | None = None) -> None: model_parent_path = model_plus.paths[0].parent vocab_path = Path(args.vocab_dir or args.model or model_parent_path) vocab_factory = VocabFactory(vocab_path) - vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path) + vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type.split(","), model_parent_path) if args.vocab_only: if not args.outfile: diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index d4b8729dd..91c39c5ae 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -378,10 +378,10 @@ int main(int argc, char ** argv) { if (params.interactive) { const char *control_message; if (params.multiline_input) { - control_message = " - To return control to LLaMa, end your input with '\\'.\n" + control_message = " - To return control to LLaMA, end your input with '\\'.\n" " - To return control without starting a new line, end your input with '/'.\n"; } else { - control_message = " - Press Return to return control to LLaMa.\n" + control_message = " - Press Return to return control to LLaMA.\n" " - To return control without starting a new line, end your input with '/'.\n" " - If you want to submit another line, end your input with '\\'.\n"; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e37da6133..d237d2e66 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1811,8 +1811,8 @@ struct llama_server_context } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) + // if input prompt is too big, truncate it, if group attention self-extend is disabled + if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; @@ -1887,9 +1887,11 @@ struct llama_server_context } LOG_INFO("slot progression", { - { "slot_id", slot.id }, - { "task_id", slot.task_id }, - { "n_past", slot.n_past }, + { "slot_id", slot.id }, + { "task_id", slot.task_id }, + { "n_past", slot.n_past }, + { "n_past_se", slot.n_past_se }, + { "ga_i", slot.ga_i }, { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed } }); } @@ -2113,6 +2115,17 @@ struct llama_server_context return true; } + + json model_meta() { + return json{ + {"vocab_type", llama_vocab_type(model)}, + {"n_vocab", llama_n_vocab(model)}, + {"n_ctx_train", llama_n_ctx_train(model)}, + {"n_embd", llama_n_embd(model)}, + {"n_params", llama_model_n_params(model)}, + {"size", llama_model_size(model)}, + }; + } }; static void server_print_usage(const char *argv0, const gpt_params ¶ms, @@ -3075,9 +3088,10 @@ int main(int argc, char **argv) for (const auto& metric_def : metrics_def) { std::string name = metric_def["name"]; std::string help = metric_def["help"]; - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << metric_def["value"] << "\n"; + auto value = json_value(metric_def, "value", 0); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; } } @@ -3165,6 +3179,7 @@ int main(int argc, char **argv) state.store(SERVER_STATE_READY); LOG_INFO("model loaded", {}); } + const auto model_meta = llama.model_meta(); if (sparams.chat_template.empty()) { // custom chat template is not supplied // check if the template comes with the model is supported by us @@ -3329,7 +3344,7 @@ int main(int argc, char **argv) } }); - svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) + svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request& req, httplib::Response& res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); @@ -3338,10 +3353,11 @@ int main(int argc, char **argv) {"object", "list"}, {"data", { { - {"id", params.model_alias}, - {"object", "model"}, - {"created", t}, - {"owned_by", "llamacpp"} + {"id", params.model_alias}, + {"object", "model"}, + {"created", t}, + {"owned_by", "llamacpp"}, + {"meta", model_meta} }, }} }; diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index 0b9fdc4e7..95a0353b6 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -1,22 +1,30 @@ # Server tests -Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/): - * [issues.feature](./features/issues.feature) Pending issues scenario - * [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests - * [security.feature](./features/security.feature) Security, CORS and API Key - * [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc... +Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) +and [behave](https://behave.readthedocs.io/en/latest/): + +* [issues.feature](./features/issues.feature) Pending issues scenario +* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests +* [security.feature](./features/security.feature) Security, CORS and API Key +* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc... Tests target GitHub workflows job runners with 4 vCPU. -Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client. +Requests are +using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) +based http client. -Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`. +Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. +To mitigate it, you can increase values in `n_predict`, `kv_size`. ### Install dependencies + `pip install -r requirements.txt` ### Run tests + 1. Build the server + ```shell cd ../../.. mkdir build @@ -24,24 +32,36 @@ cd build cmake ../ cmake --build . --target server ``` -2. download required models: - 1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf` -3. Start the test: `./tests.sh` + +2. Start the test: `./tests.sh` It's possible to override some scenario steps values with environment variables: - - `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080` - - `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server` - - `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose` - - `SERVER_LOG_FORMAT_JSON` -> if set switch server logs to json format + +| variable | description | +|--------------------------|------------------------------------------------------------------------------------------------| +| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` | +| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/server` | +| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` | +| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format | +| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` | ### Run @bug, @wip or @wrong_usage annotated scenario Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope. + - `@bug` annotation aims to link a scenario with a GitHub issue. - `@wrong_usage` are meant to show user issue that are actually an expected behavior - `@wip` to focus on a scenario working in progress +- `@slow` heavy test, disabled by default To run a scenario annotated with `@bug`, start: -`DEBUG=ON ./tests.sh --no-skipped --tags bug` + +```shell +DEBUG=ON ./tests.sh --no-skipped --tags bug +``` After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated. + +```shell +./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile" +``` diff --git a/examples/server/tests/features/environment.py b/examples/server/tests/features/environment.py index 09e826747..9fd330db6 100644 --- a/examples/server/tests/features/environment.py +++ b/examples/server/tests/features/environment.py @@ -7,7 +7,10 @@ from signal import SIGKILL def before_scenario(context, scenario): - print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m") + context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' + if context.debug: + print("DEBUG=ON\n") + print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n") port = 8080 if 'PORT' in os.environ: port = int(os.environ['PORT']) diff --git a/examples/server/tests/features/issues.feature b/examples/server/tests/features/issues.feature index bf5a175a3..7b13e44ca 100644 --- a/examples/server/tests/features/issues.feature +++ b/examples/server/tests/features/issues.feature @@ -1,4 +1,5 @@ # List of ongoing issues +# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug @bug Feature: Issues # No confirmed issue at the moment diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature index 5f895cf90..86cdf7282 100644 --- a/examples/server/tests/features/parallel.feature +++ b/examples/server/tests/features/parallel.feature @@ -1,11 +1,12 @@ @llama.cpp +@parallel Feature: Parallel Background: Server startup Given a server listening on localhost:8080 - And a model file stories260K.gguf - And a model alias tinyllama-2 + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And 42 as server seed + And 512 as batch size And 64 KV cache size And 2 slots And embeddings extraction diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature new file mode 100644 index 000000000..1bde7aab8 --- /dev/null +++ b/examples/server/tests/features/passkey.feature @@ -0,0 +1,55 @@ +# run with: ./tests.sh --no-skipped --tags passkey +@passkey +@slow +Feature: Passkey / Self-extend with context shift + + Background: Server startup + Given a server listening on localhost:8080 + + # Generates a long text of junk and inserts a secret passkey number inside it. + # Then we query the LLM for the secret passkey. + # see #3856 and #4810 + Scenario Outline: Passkey + Given a model file from HF repo + And as batch size + And as number of junk + And server max tokens to predict + And 42 as seed + And KV cache size + And 1 slots + And group attention factor to extend context size through self-extend + And group attention width to extend context size through self-extend + # Can be override with N_GPU_LAYERS + And GPU offloaded layers + Then the server is starting + Then the server is healthy + Given available models + Then model 0 is trained on tokens context + Given a prefix prompt: + """ + here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. + """ + And a passkey prompt template: + """ + The pass key is Remember it. is the pass key. + """ + And a junk suffix prompt: + """ + The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. + """ + And a suffix prompt: + """ + What is the pass key? The pass key is + """ + Given a "" passkey challenge prompt with the passkey inserted every junk + And a completion request with no api error + Then tokens are predicted matching + + Examples: + | hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content | + | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 | + | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b | + #| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 | + #| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0 + # 987 | + diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature index db06d3977..42a6709a5 100644 --- a/examples/server/tests/features/security.feature +++ b/examples/server/tests/features/security.feature @@ -1,9 +1,10 @@ @llama.cpp +@security Feature: Security Background: Server startup with an api key defined Given a server listening on localhost:8080 - And a model file stories260K.gguf + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a server api key llama.cpp Then the server is starting Then the server is healthy diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature index b571582a7..7c977bcce 100644 --- a/examples/server/tests/features/server.feature +++ b/examples/server/tests/features/server.feature @@ -1,15 +1,17 @@ @llama.cpp +@server Feature: llama.cpp server Background: Server startup Given a server listening on localhost:8080 - And a model file stories260K.gguf + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a model alias tinyllama-2 And 42 as server seed # KV Cache corresponds to the total amount of tokens # that can be stored across all independent sequences: #4130 # see --ctx-size and #5568 And 32 KV cache size + And 512 as batch size And 1 slots And embeddings extraction And 32 server max tokens to predict @@ -29,9 +31,9 @@ Feature: llama.cpp server And prometheus metrics are exposed Examples: Prompts - | prompt | n_predict | re_content | n_predicted | - | I believe the meaning of life is | 8 | (readgoing)+ | 8 | - | Write a joke about AI | 64 | (parkfriendsscaredalways)+ | 32 | + | prompt | n_predict | re_content | n_predicted | + | I believe the meaning of life is | 8 | (read\|going)+ | 8 | + | Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 | Scenario Outline: OAI Compatibility Given a model @@ -43,9 +45,9 @@ Feature: llama.cpp server Then tokens are predicted matching Examples: Prompts - | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming | - | llama-2 | Book | What is the best book | 8 | (Momwhat)+ | 8 | disabled | - | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thankshappybird)+ | 32 | enabled | + | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming | + | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled | + | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled | Scenario: Embedding When embeddings are computed for: @@ -75,10 +77,15 @@ Feature: llama.cpp server When an OAI compatible embeddings computation request for multiple inputs Then embeddings are generated - Scenario: Tokenize / Detokenize When tokenizing: """ What is the capital of France ? """ Then tokens can be detokenize + + Scenario: Models available + Given available models + Then 1 models are supported + Then model 0 is identified by tinyllama-2 + Then model 0 is trained on 128 tokens context diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 381da105e..319527802 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -13,6 +13,7 @@ import aiohttp import openai from behave import step from behave.api.async_step import async_run_until_complete +from huggingface_hub import hf_hub_download from prometheus_client import parser @@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port): context.base_url = f'http://{context.server_fqdn}:{context.server_port}' - context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' context.model_alias = None + context.n_batch = None context.n_ctx = None + context.n_ga = None + context.n_ga_w = None + context.n_gpu_layer = None context.n_predict = None context.n_server_predict = None context.n_slots = None + context.prompt_prefix = None + context.prompt_suffix = None context.server_api_key = None context.server_continuous_batching = False context.server_embeddings = False context.server_metrics = False context.server_process = None + context.seed = None context.server_seed = None context.user_api_key = None @@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port): context.prompts = [] -@step(u'a model file {model_file}') -def step_model_file(context, model_file): - context.model_file = model_file +@step(u'a model file {hf_file} from HF repo {hf_repo}') +def step_download_hf_model(context, hf_file, hf_repo): + context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file) + if context.debug: + print(f"model file: {context.model_file}\n") @step(u'a model alias {model_alias}') @@ -55,24 +64,34 @@ def step_model_alias(context, model_alias): context.model_alias = model_alias -@step(u'{seed} as server seed') +@step(u'{seed:d} as server seed') def step_seed(context, seed): - context.server_seed = int(seed) + context.server_seed = seed -@step(u'{n_ctx} KV cache size') +@step(u'{ngl:d} GPU offloaded layers') +def step_n_gpu_layer(context, ngl): + if 'N_GPU_LAYERS' in os.environ: + new_ngl = int(os.environ['N_GPU_LAYERS']) + if context.debug: + print(f"-ngl upgraded from {ngl} to {new_ngl}") + ngl = new_ngl + context.n_gpu_layer = ngl + + +@step(u'{n_ctx:d} KV cache size') def step_n_ctx(context, n_ctx): - context.n_ctx = int(n_ctx) + context.n_ctx = n_ctx -@step(u'{n_slots} slots') +@step(u'{n_slots:d} slots') def step_n_slots(context, n_slots): - context.n_slots = int(n_slots) + context.n_slots = n_slots -@step(u'{n_predict} server max tokens to predict') +@step(u'{n_predict:d} server max tokens to predict') def step_server_n_predict(context, n_predict): - context.n_server_predict = int(n_predict) + context.n_server_predict = n_predict @step(u'continuous batching') @@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status): case 'ready' | 'idle': await wait_for_health_status(context, context.base_url, 200, 'ok', + timeout=10, params={'fail_on_no_slot': 0, 'include_slots': 0}, slots_idle=context.n_slots, slots_processing=0, expected_slots=[{'id': slot_id, 'state': 0} - for slot_id in range(context.n_slots)]) + for slot_id in + range(context.n_slots if context.n_slots else 1)]) case 'busy': await wait_for_health_status(context, context.base_url, 503, 'no slot available', @@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status): slots_idle=0, slots_processing=context.n_slots, expected_slots=[{'id': slot_id, 'state': 1} - for slot_id in range(context.n_slots)]) + for slot_id in + range(context.n_slots if context.n_slots else 1)]) case _: assert False, "unknown status" @@ -157,24 +179,24 @@ async def step_request_completion(context, api_error): context.base_url, debug=context.debug, n_predict=context.n_predict, - server_seed=context.server_seed, + seed=await completions_seed(context), expect_api_error=expect_api_error, user_api_key=context.user_api_key) context.tasks_result.append(completion) if context.debug: - print(f"Completion response: {completion}") + print(f"Completion response: {completion}\n") if expect_api_error: assert completion == 401, f"completion must be an 401 status code: {completion}" -@step(u'{predicted_n} tokens are predicted matching {re_content}') +@step(u'{predicted_n:d} tokens are predicted matching {re_content}') def step_n_tokens_predicted_with_content(context, predicted_n, re_content): - assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content) + assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content) -@step(u'{predicted_n} tokens are predicted') +@step(u'{predicted_n:d} tokens are predicted') def step_n_tokens_predicted(context, predicted_n): - assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n)) + assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n) @step(u'a user prompt {user_prompt}') @@ -192,9 +214,9 @@ def step_model(context, model): context.model = model -@step(u'{max_tokens} max tokens to predict') +@step(u'{max_tokens:d} max tokens to predict') def step_max_tokens(context, max_tokens): - context.n_predict = int(max_tokens) + context.n_predict = max_tokens @step(u'streaming is {enable_streaming}') @@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key): context.server_api_key = server_api_key +@step(u'{n_junk:d} as number of junk') +def step_n_junk(context, n_junk): + context.n_junk = n_junk + + +@step(u'{n_batch:d} as batch size') +def step_n_batch(context, n_batch): + context.n_batch = n_batch + + +@step(u'{seed:d} as seed') +def step_seed(context, seed): + context.seed = seed + + +@step(u'a prefix prompt') +def step_prompt_prefix(context): + context.prompt_prefix = context.text + + +@step(u'a junk suffix prompt') +def step_prompt_junk_suffix(context): + context.prompt_junk_suffix = context.text + + +@step(u'a suffix prompt') +def step_prompt_suffix(context): + context.prompt_suffix = context.text + + +@step(u'{n_ga:d} group attention factor' + u' to extend context size through self-extend') +def step_impl(context, n_ga): + context.n_ga = n_ga + + +@step(u'{n_ga_w:d} group attention width to extend context size through self-extend') +def step_impl(context, n_ga_w): + context.n_ga_w = n_ga_w + + +@step(u'a passkey prompt template') +def step_prompt_passkey(context): + context.prompt_passkey = context.text + + +@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk') +def step_prompt_passkey(context, passkey, i_pos): + prompt = "" + for i in range(context.n_junk): + if i % context.n_junk == i_pos: + prompt += context.prompt_passkey # the passkey is already substituted + prompt += context.prompt_junk_suffix + if context.debug: + passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" + print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n") + context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) + + @step(u'an OAI compatible chat completions request with {api_error} api error') @async_run_until_complete async def step_oai_chat_completions(context, api_error): if context.debug: - print(f"Submitting OAI compatible completions request...") + print(f"Submitting OAI compatible completions request...\n") expect_api_error = api_error == 'raised' completion = await oai_chat_completions(context.prompts.pop(), context.system_prompt, @@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error): enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, - server_seed=context.server_seed - if hasattr(context, 'server_seed') else None, + seed=await completions_seed(context), user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None, @@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context): # prompt is inserted automatically context.base_url, debug=context.debug, + prompt_prefix=context.prompt_prefix, + prompt_suffix=context.prompt_suffix, n_predict=context.n_predict if hasattr(context, 'n_predict') else None, - server_seed=context.server_seed if hasattr(context, 'server_seed') else None, + seed=await completions_seed(context), user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -297,8 +379,7 @@ async def step_oai_chat_completions(context): if hasattr(context, 'n_predict') else None, enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, - server_seed=context.server_seed - if hasattr(context, 'server_seed') else None, + seed=await completions_seed(context), user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -318,7 +399,9 @@ async def step_oai_chat_completions(context): if hasattr(context, 'n_predict') else None, enable_streaming=context.enable_streaming if hasattr(context, 'enable_streaming') else None, - server_seed=context.server_seed + seed=context.seed + if hasattr(context, 'seed') else + context.server_seed if hasattr(context, 'server_seed') else None, user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context): await all_prompts_are_predicted(context) -@step(u'all prompts are predicted with {n_predict} tokens') +@step(u'all prompts are predicted with {n_expected_predicted:d} tokens') @async_run_until_complete -async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict): - expected_predicted_n = int(n_predict) - await all_prompts_are_predicted(context, expected_predicted_n) +async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted): + await all_prompts_are_predicted(context, n_expected_predicted) async def all_prompts_are_predicted(context, expected_predicted_n=None): @@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context): assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4" metrics_raw = await metrics_response.text() metric_exported = False + if context.debug: + print(f"/metrics answer:\n{metrics_raw}\n") for metric in parser.text_string_to_metric_families(metrics_raw): match metric.name: case "llamacpp:kv_cache_usage_ratio": @@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context): assert metric_exported, "No metrics exported" +@step(u'available models') +def step_available_models(context): + # openai client always expects an api_key + openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope' + openai.api_base = f'{context.base_url}/v1' + context.models = openai.Model.list().data + + +@step(u'{n_model:d} models are supported') +def step_supported_models(context, n_model): + if context.debug: + print("server models available:", context.models) + assert len(context.models) == n_model + + +@step(u'model {i_model:d} is {param} {preposition} {param_value}') +def step_supported_models(context, i_model, param, preposition, param_value): + assert i_model < len(context.models) + model = context.models[i_model] + + param_value = param_value.split(' ', 1)[0] + match param: + case 'identified': + value = model.id + case 'trained': + value = str(model.meta.n_ctx_train) + case _: + assert False, "param {param} not supported" + assert param_value == value, f"model param {param} {value} != {param_value}" + + async def concurrent_requests(context, f_completion, *args, **kwargs): n_prompts = len(context.prompts) if context.debug: @@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs): async def request_completion(prompt, base_url, debug=False, + prompt_prefix=None, + prompt_suffix=None, n_predict=None, - server_seed=None, + seed=None, expect_api_error=None, user_api_key=None): if debug: @@ -504,11 +621,14 @@ async def request_completion(prompt, async with aiohttp.ClientSession() as session: async with session.post(f'{base_url}/completion', json={ + "input_prefix": prompt_prefix, "prompt": prompt, - "n_predict": int(n_predict) if n_predict is not None else -1, - "seed": server_seed if server_seed is not None else 42 + "input_suffix": prompt_suffix, + "n_predict": n_predict if n_predict is not None else -1, + "seed": seed if seed is not None else 42 }, - headers=headers) as response: + headers=headers, + timeout=3600) as response: if expect_api_error is None or not expect_api_error: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt, model=None, n_predict=None, enable_streaming=None, - server_seed=None, + seed=None, user_api_key=None, expect_api_error=None): if debug: print(f"Sending OAI Chat completions request: {user_prompt}") # openai client always expects an api key user_api_key = user_api_key if user_api_key is not None else 'nope' - seed = server_seed if server_seed is not None else 42 + seed = seed if seed is not None else 42 enable_streaming = enable_streaming if enable_streaming is not None else False payload = { "messages": [ @@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re content = completion_response['content'] n_predicted = completion_response['timings']['predicted_n'] assert len(content) > 0, "no token predicted" - if expected_predicted_n is not None: + if re_content is not None: + p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL) + matches = p.finditer(content) + last_match = 0 + highlighted = '' + for match in matches: + start, end = match.span() + highlighted += content[last_match: start] + highlighted += '\x1b[33m' + highlighted += content[start: end] + highlighted += '\x1b[0m' + last_match = end + highlighted += content[last_match:] + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': + print(f"Checking completion response: {highlighted}\n") + assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```' + if expected_predicted_n and expected_predicted_n > 0: assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' f' {n_predicted} <> {expected_predicted_n}') - if re_content is not None: - re_content = '^.*' + re_content.replace('', '|') + '.*$' - assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), ( - f'invalid tokens predicted:' - f' ```\n{content}\n``` do not match /{re_content}/') + async def gather_tasks_results(context): n_tasks = len(context.concurrent_tasks) if context.debug: - print(f"Waiting for all {n_tasks} tasks results...") + print(f"Waiting for all {n_tasks} tasks results...\n") for task_no in range(n_tasks): context.tasks_result.append(await context.concurrent_tasks.pop()) n_completions = len(context.tasks_result) @@ -716,15 +848,13 @@ async def wait_for_health_status(context, base_url, expected_http_status_code, expected_health_status, + timeout=3, params=None, slots_idle=None, slots_processing=None, expected_slots=None): if context.debug: - print(f"Starting checking for health for expected_health_status={expected_health_status}") - timeout = 3 # seconds - if expected_health_status == 'ok': - timeout = 10 # CI slow inference + print(f"Starting checking for health for expected_health_status={expected_health_status}\n") interval = 0.5 counter = 0 async with aiohttp.ClientSession() as session: @@ -734,7 +864,7 @@ async def wait_for_health_status(context, health = await health_response.json() if context.debug: print(f"HEALTH - response for expected health status='{expected_health_status}' on " - f"'{base_url}/health'?{params} is {health}") + f"'{base_url}/health'?{params} is {health}\n") if (status_code == expected_http_status_code and health['status'] == expected_health_status and (slots_idle is None or health['slots_idle'] == slots_idle) @@ -757,7 +887,7 @@ async def wait_for_health_status(context, if expected_http_status_code == 503: if len(context.tasks_result) == 0: print("\x1b[5;37;43mWARNING: forcing concurrent tasks," - " busy health check missed, probably too fast inference\x1b[0m") + " busy health check missed, probably too fast inference\x1b[0m\n") n_completions = await gather_tasks_results(context) if n_completions > 0: return @@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots): f" = {expected[key]} != {slot[key]}") +async def completions_seed(context): + return context.seed if hasattr(context, 'seed') and context.seed is not None \ + else context.server_seed if hasattr(context, 'server_seed') else None + + def start_server_background(context): context.server_path = '../../../build/bin/server' if 'LLAMA_SERVER_BIN_PATH' in os.environ: @@ -800,27 +935,35 @@ def start_server_background(context): '--port', context.server_port, '--model', context.model_file ] + if context.n_batch: + server_args.extend(['--batch-size', context.n_batch]) + if context.n_gpu_layer: + server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) if context.server_continuous_batching: server_args.append('--cont-batching') if context.server_embeddings: server_args.append('--embedding') if context.server_metrics: server_args.append('--metrics') - if context.model_alias is not None: + if context.model_alias: server_args.extend(['--alias', context.model_alias]) - if context.n_ctx is not None: + if context.n_ctx: server_args.extend(['--ctx-size', context.n_ctx]) - if context.n_slots is not None: + if context.n_slots: server_args.extend(['--parallel', context.n_slots]) - if context.n_server_predict is not None: + if context.n_server_predict: server_args.extend(['--n-predict', context.n_server_predict]) - if context.server_api_key is not None: + if context.server_api_key: server_args.extend(['--api-key', context.server_api_key]) + if context.n_ga: + server_args.extend(['--grp-attn-n', context.n_ga]) + if context.n_ga_w: + server_args.extend(['--grp-attn-w', context.n_ga_w]) if context.debug: server_args.append('--verbose') if 'SERVER_LOG_FORMAT_JSON' not in os.environ: server_args.extend(['--log-format', "text"]) - print(f"starting server with: {context.server_path}", *server_args) + print(f"starting server with: {context.server_path} {server_args}\n") context.server_process = subprocess.Popen( [str(arg) for arg in [context.server_path, *server_args]], close_fds=True) diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature index e228b2371..cf14b3b44 100644 --- a/examples/server/tests/features/wrong_usages.feature +++ b/examples/server/tests/features/wrong_usages.feature @@ -1,4 +1,4 @@ -# run with ./test.sh --tags wrong_usage +# run with: ./tests.sh --no-skipped --tags wrong_usage @wrong_usage Feature: Wrong usage of llama.cpp server @@ -7,7 +7,7 @@ Feature: Wrong usage of llama.cpp server # or pass n_predict/max_tokens in the request. Scenario: Infinite loop Given a server listening on localhost:8080 - And a model file stories260K.gguf + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models # Uncomment below to fix the issue #And 64 server max tokens to predict Then the server is starting @@ -18,4 +18,5 @@ Feature: Wrong usage of llama.cpp server # Uncomment below to fix the issue #And 128 max tokens to predict Given concurrent completion requests + Then the server is idle Then all prompts are predicted diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 334fa4a70..5d4210164 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,4 +1,5 @@ aiohttp~=3.9.3 behave~=1.2.6 +huggingface_hub~=0.20.3 openai~=0.25.0 prometheus-client~=0.20.0 diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 17a4e6fc6..1c6c5695f 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -5,7 +5,7 @@ set -eu if [ $# -lt 1 ] then # Start @llama.cpp scenario - behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp + behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp else behave "$@" fi diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 4c7ade096..dfd38e2ce 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -219,8 +219,7 @@ static inline void server_log( for (const auto& el : log.items()) { const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); - snprintf(buf, 1024, "\033[85;0H %s=%s", el.key().c_str(), value.c_str()); - ss << buf; + ss << " " << el.key() << "=" << value; } const std::string str = ss.str(); diff --git a/flake.lock b/flake.lock index 9f659ba8f..b1b091656 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1706830856, - "narHash": "sha256-a0NYyp+h9hlb7ddVz4LUn1vT/PLwqfrWYcHMvFB1xYg=", + "lastModified": 1709336216, + "narHash": "sha256-Dt/wOWeW6Sqm11Yh+2+t0dfEWxoMxGBvv3JpIocFl9E=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "b253292d9c0a5ead9bc98c4e9a26c6312e27d69f", + "rev": "f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1708655239, - "narHash": "sha256-ZrP/yACUvDB+zbqYJsln4iwotbH6CTZiTkANJ0AgDv4=", + "lastModified": 1709237383, + "narHash": "sha256-cy6ArO4k5qTx+l5o+0mL9f5fa86tYUX3ozE1S+Txlds=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "cbc4211f0afffe6dfd2478a62615dd5175a13f9a", + "rev": "1536926ef5621b09bba54035ae2bb6d806d72ac8", "type": "github" }, "original": { @@ -37,11 +37,11 @@ "nixpkgs-lib": { "locked": { "dir": "lib", - "lastModified": 1706550542, - "narHash": "sha256-UcsnCG6wx++23yeER4Hg18CXWbgNpqNXcHIo5/1Y+hc=", + "lastModified": 1709237383, + "narHash": "sha256-cy6ArO4k5qTx+l5o+0mL9f5fa86tYUX3ozE1S+Txlds=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "97b17f32362e475016f942bbdfda4a4a72a8a652", + "rev": "1536926ef5621b09bba54035ae2bb6d806d72ac8", "type": "github" }, "original": { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0c6501e98..7ed97430f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2018,74 +2018,73 @@ static const __device__ uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const __device__ uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +static const __device__ uint32_t iq3s_grid[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, }; - static const __device__ uint64_t iq1s_grid[512] = { 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, @@ -2392,9 +2391,9 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; const uint8_t * qs = x[i].qs + 8*ib; - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); - const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); const uint8_t signs = x[i].signs[4*ib + il]; for (int j = 0; j < 4; ++j) { y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); @@ -5211,8 +5210,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const int8_t * q8 = bq8_1[ib32].qs; int sumi = 0; for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); @@ -5221,7 +5220,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( sumi = __dp4a(grid_h, *((int *)q8+1), sumi); q8 += 8; } - const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; + const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); return d * sumi; #else assert(false); diff --git a/ggml-metal.metal b/ggml-metal.metal index 74a5e0b03..8b9488437 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4087,71 +4087,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -constexpr constant static uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +constexpr constant static uint32_t iq3s_grid[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, }; #define NGRID_IQ1S 512 @@ -4742,7 +4742,7 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4769,12 +4769,14 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; row++) { const float db = dh[0]; - const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf)); + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); @@ -4795,7 +4797,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -5685,15 +5687,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & device const uint8_t * qs = xb->qs + 8*ib32; device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256))); + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); for (int i = 0; i < 4; ++i) { reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256))); + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); for (int i = 0; i < 4; ++i) { reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); diff --git a/ggml-quants.c b/ggml-quants.c index 371826f14..2a8881d73 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3818,71 +3818,71 @@ static const uint32_t iq3xxs_grid[256] = { 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, }; -static const uint32_t iq3xs_grid[512] = { - 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, - 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, - 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, - 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, - 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, - 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, - 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, - 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, - 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, - 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, - 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, - 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, - 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, - 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, - 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, - 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, - 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, - 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, - 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, - 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, - 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, - 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, - 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, - 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, - 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, - 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, - 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, - 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, - 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, - 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, - 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, - 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, - 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, - 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, - 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, - 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, - 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, - 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, - 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, - 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, - 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, - 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, - 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, - 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, - 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, - 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, - 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, - 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, - 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, - 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, - 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, - 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, - 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, - 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, - 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, - 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, - 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, - 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, - 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, - 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, - 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, - 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, - 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, - 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +static const uint32_t iq3s_grid[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, }; #define NGRID_IQ2XXS 512 @@ -4162,11 +4162,11 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in const uint8_t * signs = x[i].signs; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; - const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; + const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); + const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); @@ -4176,8 +4176,8 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); @@ -10089,18 +10089,34 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v #if defined(__ARM_NEON) + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 }; static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); uint8x16x2_t vs; ggml_int8x16x4_t q3s; ggml_int8x16x4_t q8b; + vec_index_t idx; + +#if QK_K == 256 + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; +#endif float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -10109,47 +10125,63 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint8_t * restrict qh = x[i].qh; const uint16_t * restrict signs = (const uint16_t *)x[i].signs; const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; +#endif + int sumi1 = 0, sumi2 = 0; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]}; - const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]}; - const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]}; - const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]}; - qs += 16; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]}; + const uint32x4_t aux32x4_1 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]}; + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = {iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]}; + const uint32x4_t aux32x4_3 = {iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]}; + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); signs += 4; - q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0])); - q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1])); + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); +#if QK_K == 256 + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; +#else sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif } sumf += d*(sumi1 + sumi2); } - *s = 0.25f * sumf; + *s = sumf; #elif defined(__AVX2__) @@ -10164,6 +10196,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -10176,24 +10218,25 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]); - qs += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]); - qs += 8; + const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; + idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); + idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); + idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); + idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -10221,7 +10264,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } - *s = 0.25f * hsum_float_8(accumf); + *s = hsum_float_8(accumf); #else @@ -10238,8 +10281,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; int32_t sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); @@ -10251,8 +10294,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v bsum += sumi * ls1; sumi = 0; for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); @@ -10265,7 +10308,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } sumf += d * bsum; } - *s = 0.25f * sumf; + *s = sumf; #endif } @@ -11912,7 +11955,8 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } float best = 0; float scale = max/(2*kMaxQ-1); - for (int is = -15; is <= 15; ++is) { + for (int k = 0; k < bs4; ++k) is_on_grid[k] = false; + for (int is = -9; is <= 9; ++is) { float id = (2*kMaxQ-1+is*0.2f)/max; float this_scale = 1/id; for (int k = 0; k < bs4; ++k) { @@ -11948,7 +11992,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo if (n_not_ongrid > 0 && scale > 0) { float id = 1/scale; for (int k = 0; k < bs4; ++k) { - if (is_on_grid[k]) continue; + //if (is_on_grid[k]) continue; uint16_t u = 0; for (int i = 0; i < 4; ++i) { int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); @@ -12004,7 +12048,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } float d = max_scale/31; - y[ibl].d = GGML_FP32_TO_FP16(d); + y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f); float id = 1/d; for (int ib = 0; ib < QK_K/block_size; ib += 2) { int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e4681475c..801160832 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -362,7 +362,7 @@ class GGUFWriter: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) def add_pooling_type(self, value: PoolingType) -> None: - self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value) + self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) diff --git a/llama.cpp b/llama.cpp index 05e2980f3..02c643a2f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -216,7 +216,7 @@ enum llm_arch { LLM_ARCH_UNKNOWN, }; -static std::map LLM_ARCH_NAMES = { +static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GPT2, "gpt2" }, @@ -241,6 +241,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { @@ -301,7 +302,7 @@ enum llm_kv { LLM_KV_TOKENIZER_RWKV, }; -static std::map LLM_KV_NAMES = { +static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, @@ -365,7 +366,7 @@ struct LLM_KV { llm_arch arch; std::string operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES[kv], LLM_ARCH_NAMES[arch]); + return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } }; @@ -400,7 +401,7 @@ enum llm_tensor { LLM_TENSOR_LAYER_OUT_NORM, }; -static std::map> LLM_TENSOR_NAMES = { +static const std::map> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, { @@ -833,38 +834,38 @@ struct LLM_TN { llm_arch arch; std::string operator()(llm_tensor tensor) const { - if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES[arch].at(tensor); + return LLM_TENSOR_NAMES.at(arch).at(tensor); } std::string operator()(llm_tensor tensor, const std::string & suffix) const { - if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; + return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix; } std::string operator()(llm_tensor tensor, int bid) const { - if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid); } std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix; } std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { - if (LLM_TENSOR_NAMES[arch].find(tensor) == LLM_TENSOR_NAMES[arch].end()) { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix; + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; } }; @@ -872,7 +873,7 @@ struct LLM_TN { // gguf helpers // -static std::map LLAMA_ROPE_SCALING_TYPES = { +static const std::map LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, @@ -1986,6 +1987,9 @@ struct llama_context { std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; + ggml_abort_callback abort_callback = nullptr; + void * abort_callback_data = nullptr; + // input tensors ggml_backend_buffer_t buf_input = nullptr; ggml_context * ctx_input = nullptr; @@ -8070,6 +8074,7 @@ static void llama_graph_compute( if (lctx.backend_cpu != nullptr) { ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); + ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); } ggml_backend_sched_graph_compute(lctx.sched, gf); @@ -10835,7 +10840,7 @@ struct quantize_state_internal { {} }; -static void llama_convert_tensor_internal( +static void llama_tensor_dequantize_internal( struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread ) { @@ -11176,6 +11181,46 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty return new_type; } +static int32_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int chunk_size, int nrows, int n_per_row, int64_t * hist_cur, const float * imatrix, std::vector & workers, const int nthread) { + std::mutex mutex; + int counter = 0; + size_t new_size = 0; + if (nthread < 2) { + // single-thread + return ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur, imatrix); + } + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size, + nrows, n_per_row, imatrix]() { + std::array local_hist = {}; + const int nrows_per_chunk = chunk_size / n_per_row; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + int first_row = counter; counter += nrows_per_chunk; + if (first_row >= nrows) { + if (local_size > 0) { + for (int j=0; jftype; @@ -11288,7 +11333,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector workers; workers.reserve(nthread); - std::mutex mutex; int idx = 0; @@ -11402,7 +11446,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { - llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread); + llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread); f32_data = (float *) f32_conv_buf.data(); } @@ -11423,41 +11467,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const int nchunk = (nelements + chunk_size - 1)/chunk_size; const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; - if (nthread_use < 2) { - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur.data(), imatrix); - } else { - int counter = 0; - new_size = 0; - auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size, - nrows, n_per_row, imatrix]() { - std::array local_hist = {}; - const int nrows_per_chunk = chunk_size / n_per_row; - size_t local_size = 0; - while (true) { - std::unique_lock lock(mutex); - int first_row = counter; counter += nrows_per_chunk; - if (first_row >= nrows) { - if (local_size > 0) { - for (int j=0; j %8.2f MiB", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); int64_t tot_count = 0; @@ -11850,6 +11860,8 @@ struct llama_context_params llama_context_default_params() { /*.embedding =*/ false, /*.offload_kqv =*/ true, /*.do_pooling =*/ true, + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, }; return result; @@ -12032,8 +12044,11 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); - ctx->rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + ctx->abort_callback = params.abort_callback; + ctx->abort_callback_data = params.abort_callback_data; + + ctx->rng = std::mt19937(params.seed); + ctx->logits_all = params.logits_all; const ggml_type type_k = params.type_k; const ggml_type type_v = params.type_v; @@ -12983,6 +12998,11 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ ctx->cparams.n_threads_batch = n_threads_batch; } +void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, diff --git a/llama.h b/llama.h index ed51f478a..6406b5270 100644 --- a/llama.h +++ b/llama.h @@ -255,10 +255,16 @@ extern "C" { enum ggml_type type_v; // data type for V cache // Keep the booleans together to avoid misalignment during copy-by-value. - bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) + bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embedding; // embedding mode only bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + + // Abort callback + // if it returns true, execution of llama_decode() will be aborted + // currently works only with CPU execution + ggml_abort_callback abort_callback; + void * abort_callback_data; }; // model quantization parameters @@ -632,7 +638,10 @@ extern "C" { // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); - // Token logits obtained from the last call to llama_eval() + // Set abort callback + LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); + + // Token logits obtained from the last call to llama_decode() // The logits for the last token are stored in the last row // Logits for which llama_batch.logits[i] == 0 are undefined // Rows: n_tokens provided with llama_batch diff --git a/scripts/pod-llama.sh b/scripts/pod-llama.sh new file mode 100644 index 000000000..6cf1ab4f3 --- /dev/null +++ b/scripts/pod-llama.sh @@ -0,0 +1,213 @@ +#!/bin/bash +# +# Use this script only on fresh pods (runpod.io)! +# Otherwise, it can break your environment! +# + +if [ -z "$1" ]; then + echo "Usage: $0 " + echo " 0: no models" + echo " 1: tinyllama-1b" + echo " 2: codellama-7b" + echo " 3: codellama-13b" + echo " 4: codellama-34b" + echo " 5: codellama-7b-instruct" + echo " 6: codellama-13b-instruct" + echo " 7: codellama-34b-instruct" + + exit 1 +fi + +set -x + +# setup deps +apt-get update +apt-get install -y git-lfs cmake cmake-curses-gui vim ruby +git-lfs install + +if [ ! -d "/workspace" ]; then + ln -sfn $(pwd) /workspace +fi + +# download data +cd /workspace + +# this is useful to git clone repos without doubling the disk size due to .git +git clone https://github.com/iboB/git-lfs-download +ln -sfn /workspace/git-lfs-download/git-lfs-download /usr/local/bin/git-lfs-download + +# llama.cpp +cd /workspace +git clone https://github.com/ggerganov/llama.cpp + +cd llama.cpp + +LLAMA_CUBLAS=1 make -j + +ln -sfn /workspace/TinyLlama-1.1B-Chat-v0.3 ./models/tinyllama-1b +ln -sfn /workspace/CodeLlama-7b-hf ./models/codellama-7b +ln -sfn /workspace/CodeLlama-13b-hf ./models/codellama-13b +ln -sfn /workspace/CodeLlama-34b-hf ./models/codellama-34b +ln -sfn /workspace/CodeLlama-7b-Instruct-hf ./models/codellama-7b-instruct +ln -sfn /workspace/CodeLlama-13b-Instruct-hf ./models/codellama-13b-instruct +ln -sfn /workspace/CodeLlama-34b-Instruct-hf ./models/codellama-34b-instruct + +pip install -r requirements.txt + +# cmake +cd /workspace/llama.cpp + +mkdir build-cublas +cd build-cublas + +cmake -DLLAMA_CUBLAS=1 ../ +make -j + +if [ "$1" -eq "0" ]; then + exit 0 +fi + +# more models +if [ "$1" -eq "1" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3 + + cd /workspace/llama.cpp + + python3 convert.py ./models/tinyllama-1b --outfile ./models/tinyllama-1b/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_k.gguf q4_k + ./quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "2" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-hf --without *safetensors* + rm -v ./CodeLlama-7b-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-7b --outfile ./models/codellama-7b/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "3" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-hf --without *safetensors* + rm -v ./CodeLlama-13b-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-13b --outfile ./models/codellama-13b/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "4" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-hf --without *safetensors* + rm -v ./CodeLlama-34b-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-34b --outfile ./models/codellama-34b/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "5" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf --without *safetensors* + rm -v ./CodeLlama-7b-Instruct-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-7b-instruct --outfile ./models/codellama-7b-instruct/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "6" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf --without *safetensors* + rm -v ./CodeLlama-13b-Instruct-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-13b-instruct --outfile ./models/codellama-13b-instruct/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "7" ]; then + cd /workspace + + git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf --without *safetensors* + rm -v ./CodeLlama-34b-Instruct-hf/*safetensors* + + cd /workspace/llama.cpp + + python3 convert.py ./models/codellama-34b-instruct --outfile ./models/codellama-34b-instruct/ggml-model-f16.gguf --outtype f16 + + ./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_0.gguf q4_0 + ./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_k.gguf q4_k + ./quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q8_0.gguf q8_0 +fi + +if [ "$1" -eq "1" ]; then + # perf + perplexity + cd /workspace/llama.cpp/build-cublas + + make -j && ../scripts/run-all-perf.sh tinyllama-1b "f16" "-ngl 99 -t 1 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048 -n 128" + + ../scripts/get-wikitext-2.sh + unzip wikitext-2-raw-v1.zip + + make -j && ./bin/perplexity -m ../models/tinyllama-1b/ggml-model-f16.gguf -f ./wikitext-2-raw/wiki.test.raw -ngl 100 --chunks 32 + + # batched + cd /workspace/llama.cpp + + LLAMA_CUBLAS=1 make -j && ./batched ./models/tinyllama-1b/ggml-model-f16.gguf "Hello, my name is" 8 128 999 + + # batched-bench + cd /workspace/llama.cpp + + LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/tinyllama-1b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32 + + # parallel + cd /workspace/llama.cpp + + LLAMA_CUBLAS=1 make -j && ./parallel -m ./models/tinyllama-1b/ggml-model-f16.gguf -t 1 -ngl 100 -c 4096 -b 512 -s 1 -np 8 -ns 128 -n 100 -cb + +fi + +# speculative +#if [ "$1" -eq "7" ]; then +# cd /workspace/llama.cpp +# +# LLAMA_CUBLAS=1 make -j && ./speculative -m ./models/codellama-34b-instruct/ggml-model-f16.gguf -md ./models/codellama-7b-instruct/ggml-model-q4_0.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0 +#fi + +# more benches +#LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/codellama-7b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1 +#LLAMA_CUBLAS=1 make -j && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1 +