diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1f275603a..a1e183d11 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -860,7 +860,7 @@ jobs: mkdir build cd build cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON - cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} + cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) - name: Determine tag name id: tag diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9097e7103..71ae1513e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -48,7 +48,7 @@ class Model: dir_model: Path ftype: gguf.LlamaFileType - fname_out: Path | None + fname_out: Path is_big_endian: bool endianess: gguf.GGUFEndian use_temp_file: bool @@ -62,11 +62,12 @@ class Model: gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None + dir_model_card: Path # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False, + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False): @@ -90,6 +91,7 @@ class Model: self.tensor_names = None self.metadata_override = metadata_override self.model_name = model_name + self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -345,7 +347,7 @@ class Model: total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count() - self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params) + self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) # Fallback to model directory name if metadata name is still missing if self.metadata.name is None: @@ -359,27 +361,22 @@ class Model: output_type: str = self.ftype.name.partition("_")[2] # Filename Output - # Note: `not is_dir()` is used because `.is_file()` will not detect - # file template strings as it doesn't actually exist as a file - if self.fname_out is not None and not self.fname_out.is_dir(): - # Output path is a custom defined templated filename - - # Process templated file name with the output ftype, useful with the "auto" ftype - self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) - else: + if self.fname_out.is_dir(): # Generate default filename based on model specification and available metadata if not vocab_only: fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None) else: fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab") - # Check if preferred output directory path was provided - if self.fname_out is not None and self.fname_out.is_dir(): - # output path is a directory - self.fname_out = self.fname_out / f"{fname_default}.gguf" - else: - # output in the same directory as the model by default - self.fname_out = self.dir_model / f"{fname_default}.gguf" + # Use the default filename + self.fname_out = self.fname_out / f"{fname_default}.gguf" + else: + # Output path is a custom defined templated filename + # Note: `not is_dir()` is used because `.is_file()` will not detect + # file template strings as it doesn't actually exist as a file + + # Process templated file name with the output ftype, useful with the "auto" ftype + self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) self.set_type() @@ -596,6 +593,9 @@ class Model: if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f": # ref: https://huggingface.co/WisdomShell/CodeShell-7B res = "codeshell" + if chkhsh == "63b97e4253352e6f357cc59ea5b583e3a680eaeaf2632188c2b952de2588485e": + # ref: https://huggingface.co/mistralai/Mistral-Nemo-Base-2407 + res = "tekken" if res is None: logger.warning("\n") @@ -753,7 +753,8 @@ class Model: token_id = int(token_id) token: str = token_data["content"] if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: - assert tokens[token_id] == token.encode("utf-8") + if tokens[token_id] != token.encode("utf-8"): + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}') if token_data.get("special") or self.does_token_look_special(token): toktypes[token_id] = SentencePieceTokenTypes.CONTROL else: @@ -1312,6 +1313,7 @@ class RefactModel(Model): special_vocab._set_special_token("prefix", 1) special_vocab._set_special_token("suffix", 3) special_vocab._set_special_token("middle", 2) + special_vocab.chat_template = None # do not add it twice special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): @@ -2014,7 +2016,8 @@ class Phi3MiniModel(Model): token_id = int(token_id) token = foken_data["content"].encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: - assert tokens[token_id] == token + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -2030,7 +2033,8 @@ class Phi3MiniModel(Model): token_id = int(foken_data["id"]) token = foken_data["content"].encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: - assert tokens[token_id] == token + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -2269,7 +2273,8 @@ class InternLM2Model(Model): chat_eos_token_id = token_id token = token.encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: - assert(tokens[token_id] == token) + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -2288,7 +2293,8 @@ class InternLM2Model(Model): chat_eos_token_id = token_id token = token.encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: - assert(tokens[token_id] == token) + if tokens[token_id] != token: + logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -2474,6 +2480,7 @@ class GemmaModel(Model): special_vocab._set_special_token("middle", 68) special_vocab._set_special_token("fsep", 70) special_vocab._set_special_token("eot", 107) + special_vocab.chat_template = None # do not add it twice special_vocab.add_to_gguf(self.gguf_writer) self.gguf_writer.add_add_space_prefix(False) @@ -3627,10 +3634,10 @@ def main() -> None: logger.error("Error: Cannot use temp file when splitting") sys.exit(1) - fname_out = None - if args.outfile is not None: fname_out = args.outfile + else: + fname_out = dir_model logger.info(f"Loading model: {dir_model.name}") @@ -3661,7 +3668,6 @@ def main() -> None: else: logger.info("Exporting model...") model_instance.write() - assert model_instance.fname_out is not None out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out logger.info(f"Model successfully exported to {out_path}") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index e159c0f71..a92379666 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -92,6 +92,7 @@ models = [ {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, + {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, ] diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index 66e8da37c..a88d0d4a9 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -290,7 +290,7 @@ if __name__ == '__main__': fname_out = args.outfile else: # output in the same directory as the model by default - fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' + fname_out = dir_lora if os.path.exists(input_model): # lazy import load_file only if lora is in safetensors format. @@ -304,12 +304,6 @@ if __name__ == '__main__': # load base model logger.info(f"Loading base model: {dir_base_model.name}") hparams = Model.load_hparams(dir_base_model) - - with open(lora_config, "r") as f: - lparams: dict[str, Any] = json.load(f) - - alpha: float = lparams["lora_alpha"] - with torch.inference_mode(): try: model_class = Model.from_model_architecture(hparams["architectures"][0]) @@ -320,12 +314,21 @@ if __name__ == '__main__': class LoraModel(model_class): model_arch = model_class.model_arch + lora_alpha: float + + def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs): + + super().__init__(*args, **kwargs) + + self.dir_model_card = dir_lora_model + self.lora_alpha = float(lora_alpha) + def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.ADAPTER) self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") def set_gguf_parameters(self): - self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha)) + self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) super().set_gguf_parameters() def get_tensors(self) -> Iterator[tuple[str, Tensor]]: @@ -368,6 +371,11 @@ if __name__ == '__main__': yield (dest_name + ".lora_a", lora_a) yield (dest_name + ".lora_b", lora_b) + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + + alpha: float = lparams["lora_alpha"] + model_instance = LoraModel( dir_base_model, ftype, @@ -376,6 +384,8 @@ if __name__ == '__main__': use_temp_file=False, eager=args.no_lazy, dry_run=args.dry_run, + dir_lora_model=dir_lora, + lora_alpha=alpha, ) logger.info("Exporting model...") diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index 575143771..7498f85ef 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -92,6 +92,11 @@ static bool gguf_ex_read_0(const std::string & fname) { struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + if (!ctx) { + fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str()); + return false; + } + printf("%s: version: %d\n", __func__, gguf_get_version(ctx)); printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 2a3f9f758..58c32ca53 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -26,11 +26,12 @@ actor LlamaContext { private var context: OpaquePointer private var batch: llama_batch private var tokens_list: [llama_token] + var is_done: Bool = false /// This variable is used to store temporarily invalid cchars private var temporary_invalid_cchars: [CChar] - var n_len: Int32 = 64 + var n_len: Int32 = 1024 var n_cur: Int32 = 0 var n_decode: Int32 = 0 @@ -160,6 +161,7 @@ actor LlamaContext { if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") + is_done = true let new_token_str = String(cString: temporary_invalid_cchars + [0]) temporary_invalid_cchars.removeAll() return new_token_str diff --git a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift index 2c1e3f61b..b8f6a31d5 100644 --- a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift +++ b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift @@ -132,7 +132,7 @@ class LlamaState: ObservableObject { messageLog += "\(text)" Task.detached { - while await llamaContext.n_cur < llamaContext.n_len { + while await !llamaContext.is_done { let result = await llamaContext.completion_loop() await MainActor.run { self.messageLog += "\(result)" diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py old mode 100644 new mode 100755 index 504ed98df..eb000d5cc --- a/examples/pydantic_models_to_grammar_examples.py +++ b/examples/pydantic_models_to_grammar_examples.py @@ -1,8 +1,15 @@ -# Function calling example using pydantic models. +#!/usr/bin/env python3 + +"""Function calling example using pydantic models.""" + from __future__ import annotations +import argparse import datetime import json +import logging +import textwrap +import sys from enum import Enum from typing import Optional, Union @@ -12,30 +19,54 @@ from pydantic_models_to_grammar import (add_run_method_to_dynamic_model, convert create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation) -# Function to get completion on the llama.cpp server with grammar. -def create_completion(prompt, grammar): +def create_completion(host, prompt, gbnf_grammar): + """Calls the /completion API on llama-server. + + See + https://github.com/ggerganov/llama.cpp/tree/HEAD/examples/server#api-endpoints + """ + print(f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}") headers = {"Content-Type": "application/json"} - data = {"prompt": prompt, "grammar": grammar} - - response = requests.post("http://127.0.0.1:8080/completion", headers=headers, json=data) - data = response.json() - + data = {"prompt": prompt, "grammar": gbnf_grammar} + result = requests.post(f"http://{host}/completion", headers=headers, json=data).json() assert data.get("error") is None, data - - print(data["content"]) - return data["content"] + logging.info("Result: %s", result) + content = result["content"] + print(f" Model: {result['model']}") + print(f" Result:\n{textwrap.indent(json.dumps(json.loads(content), indent=2), ' ')}") + return content # A function for the agent to send a message to the user. class SendMessageToUser(BaseModel): - """ - Send a message to the User. - """ + """Send a message to the User.""" chain_of_thought: str = Field(..., description="Your chain of thought while sending the message.") message: str = Field(..., description="Message you want to send to the user.") def run(self): - print(self.message) + print(f"SendMessageToUser: {self.message}") + + +def example_rce(host): + """Minimal test case where the LLM call an arbitrary python function.""" + print("- example_rce") + tools = [SendMessageToUser] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( + pydantic_model_list=tools, outer_object_name="function", + outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") + system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + user_message = "What is 42 * 42?" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + tools_map = {tool.__name__:tool for tool in tools} + # This finds "SendMessageToUser": + tool = tools_map.get(json_data["function"]) + if not tool: + print(f"Error: unknown tool {json_data['function']}") + return 1 + tool(**json_data["function_parameters"]).run() + return 0 # Enum for the calculator tool. @@ -46,11 +77,11 @@ class MathOperation(Enum): DIVIDE = "divide" -# Simple pydantic calculator tool for the agent that can add, subtract, multiply, and divide. Docstring and description of fields will be used in system prompt. +# Simple pydantic calculator tool for the agent that can add, subtract, +# multiply, and divide. Docstring and description of fields will be used in +# system prompt. class Calculator(BaseModel): - """ - Perform a math operation on two numbers. - """ + """Perform a math operation on two numbers.""" number_one: Union[int, float] = Field(..., description="First number.") operation: MathOperation = Field(..., description="Math operation to perform.") number_two: Union[int, float] = Field(..., description="Second number.") @@ -68,55 +99,61 @@ class Calculator(BaseModel): raise ValueError("Unknown operation.") -# Here the grammar gets generated by passing the available function models to generate_gbnf_grammar_and_documentation function. This also generates a documentation usable by the LLM. -# pydantic_model_list is the list of pydanitc models -# outer_object_name is an optional name for an outer object around the actual model object. Like a "function" object with "function_parameters" which contains the actual model object. If None, no outer object will be generated -# outer_object_content is the name of outer object content. -# model_prefix is the optional prefix for models in the documentation. (Default="Output Model") -# fields_prefix is the prefix for the model fields in the documentation. (Default="Output Fields") -gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( - pydantic_model_list=[SendMessageToUser, Calculator], outer_object_name="function", - outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") +def example_calculator(host): + """Have the LLM ask to get a calculation done. -print(gbnf_grammar) -print(documentation) + Here the grammar gets generated by passing the available function models to + generate_gbnf_grammar_and_documentation function. This also generates a + documentation usable by the LLM. -system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + pydantic_model_list is the list of pydantic models outer_object_name is an + optional name for an outer object around the actual model object. Like a + "function" object with "function_parameters" which contains the actual model + object. If None, no outer object will be generated outer_object_content is + the name of outer object content. -user_message = "What is 42 * 42?" -prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" - -text = create_completion(prompt=prompt, grammar=gbnf_grammar) -# This should output something like this: -# { -# "function": "calculator", -# "function_parameters": { -# "number_one": 42, -# "operation": "multiply", -# "number_two": 42 -# } -# } -function_dictionary = json.loads(text) -if function_dictionary["function"] == "calculator": - function_parameters = {**function_dictionary["function_parameters"]} - - print(Calculator(**function_parameters).run()) - # This should output: 1764 + model_prefix is the optional prefix for models in the documentation. (Default="Output Model") + fields_prefix is the prefix for the model fields in the documentation. (Default="Output Fields") + """ + print("- example_calculator") + tools = [SendMessageToUser, Calculator] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( + pydantic_model_list=tools, outer_object_name="function", + outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") + system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + user_message1 = "What is 42 * 42?" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message1}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + expected = { + "function": "Calculator", + "function_parameters": { + "number_one": 42, + "operation": "multiply", + "number_two": 42 + } + } + if json_data != expected: + print(" Result is not as expected!") + tools_map = {tool.__name__:tool for tool in tools} + # This finds "Calculator": + tool = tools_map.get(json_data["function"]) + if not tool: + print(f"Error: unknown tool {json_data['function']}") + return 1 + result = tool(**json_data["function_parameters"]).run() + print(f" Call {json_data['function']} gave result {result}") + return 0 -# A example structured output based on pydantic models. The LLM will create an entry for a Book database out of an unstructured text. class Category(Enum): - """ - The category of the book. - """ + """The category of the book.""" Fiction = "Fiction" NonFiction = "Non-Fiction" class Book(BaseModel): - """ - Represents an entry about a book. - """ + """Represents an entry about a book.""" title: str = Field(..., description="Title of the book.") author: str = Field(..., description="Author of the book.") published_year: Optional[int] = Field(..., description="Publishing year of the book.") @@ -125,33 +162,42 @@ class Book(BaseModel): summary: str = Field(..., description="Summary of the book.") -# We need no additional parameters other than our list of pydantic models. -gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation([Book]) +def example_struct(host): + """A example structured output based on pydantic models. -system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation + The LLM will create an entry for a Book database out of an unstructured + text. We need no additional parameters other than our list of pydantic + models. + """ + print("- example_struct") + tools = [Book] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(pydantic_model_list=tools) + system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation + text = """The Feynman Lectures on Physics is a physics textbook based on some lectures by Richard Feynman, a Nobel laureate who has sometimes been called "The Great Explainer". The lectures were presented before undergraduate students at the California Institute of Technology (Caltech), during 1961–1963. The book's co-authors are Feynman, Robert B. Leighton, and Matthew Sands.""" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + # In this case, there's no function nor function_parameters. + # Here the result will vary based on the LLM used. + keys = sorted(["title", "author", "published_year", "keywords", "category", "summary"]) + if keys != sorted(json_data.keys()): + print(f"Unexpected result: {sorted(json_data.keys())}") + return 1 + book = Book(**json_data) + print(f" As a Book object: %s" % book) + return 0 -text = """The Feynman Lectures on Physics is a physics textbook based on some lectures by Richard Feynman, a Nobel laureate who has sometimes been called "The Great Explainer". The lectures were presented before undergraduate students at the California Institute of Technology (Caltech), during 1961–1963. The book's co-authors are Feynman, Robert B. Leighton, and Matthew Sands.""" -prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" - -text = create_completion(prompt=prompt, grammar=gbnf_grammar) - -json_data = json.loads(text) - -print(Book(**json_data)) -# An example for parallel function calling with a Python function, a pydantic function model and an OpenAI like function definition. def get_current_datetime(output_format: Optional[str] = None): - """ - Get the current date and time in the given format. + """Get the current date and time in the given format. + Args: output_format: formatting string for the date and time, defaults to '%Y-%m-%d %H:%M:%S' """ - if output_format is None: - output_format = '%Y-%m-%d %H:%M:%S' - return datetime.datetime.now().strftime(output_format) + return datetime.datetime.now().strftime(output_format or "%Y-%m-%d %H:%M:%S") -# Example function to get the weather +# Example function to get the weather. def get_current_weather(location, unit): """Get the current weather in a given location""" if "London" in location: @@ -160,68 +206,107 @@ def get_current_weather(location, unit): return json.dumps({"location": "New York", "temperature": "24", "unit": unit.value}) elif "North Pole" in location: return json.dumps({"location": "North Pole", "temperature": "-42", "unit": unit.value}) - else: - return json.dumps({"location": location, "temperature": "unknown"}) + return json.dumps({"location": location, "temperature": "unknown"}) -# Here is a function definition in OpenAI style -current_weather_tool = { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", +def example_concurrent(host): + """An example for parallel function calling with a Python function, a pydantic + function model and an OpenAI like function definition. + """ + print("- example_concurrent") + # Function definition in OpenAI style. + current_weather_tool = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + "required": ["location"], }, - "required": ["location"], }, - }, -} + } + # Convert OpenAI function definition into pydantic model. + current_weather_tool_model = convert_dictionary_to_pydantic_model(current_weather_tool) + # Add the actual function to a pydantic model. + current_weather_tool_model = add_run_method_to_dynamic_model(current_weather_tool_model, get_current_weather) -# Convert OpenAI function definition into pydantic model -current_weather_tool_model = convert_dictionary_to_pydantic_model(current_weather_tool) -# Add the actual function to a pydantic model -current_weather_tool_model = add_run_method_to_dynamic_model(current_weather_tool_model, get_current_weather) + # Convert normal Python function to a pydantic model. + current_datetime_model = create_dynamic_model_from_function(get_current_datetime) -# Convert normal Python function to a pydantic model -current_datetime_model = create_dynamic_model_from_function(get_current_datetime) - -tool_list = [SendMessageToUser, Calculator, current_datetime_model, current_weather_tool_model] + tools = [SendMessageToUser, Calculator, current_datetime_model, current_weather_tool_model] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( + pydantic_model_list=tools, outer_object_name="function", + outer_object_content="params", model_prefix="Function", fields_prefix="Parameters", list_of_outputs=True) + system_message = "You are an advanced AI assistant. You are interacting with the user and with your environment by calling functions. You call functions by writing JSON objects, which represent specific function calls.\nBelow is a list of your available function calls:\n\n" + documentation + text = """Get the date and time, get the current weather in celsius in London and solve the following calculation: 42 * 42""" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + expected = [ + { + "function": "get_current_datetime", + "params": { + "output_format": "%Y-%m-%d %H:%M:%S" + } + }, + { + "function": "get_current_weather", + "params": { + "location": "London", + "unit": "celsius" + } + }, + { + "function": "Calculator", + "params": { + "number_one": 42, + "operation": "multiply", + "number_two": 42 + } + } + ] + res = 0 + if json_data != expected: + print(" Result is not as expected!") + print(" This can happen on highly quantized models") + res = 1 + tools_map = {tool.__name__:tool for tool in tools} + for call in json_data: + tool = tools_map.get(call["function"]) + if not tool: + print(f"Error: unknown tool {call['function']}") + return 1 + result = tool(**call["params"]).run() + print(f" Call {call['function']} returned {result}") + # Should output something like this: + # Call get_current_datetime returned 2024-07-15 09:50:38 + # Call get_current_weather returned {"location": "London", "temperature": "42", "unit": "celsius"} + # Call Calculator returned 1764 + return res -gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( - pydantic_model_list=tool_list, outer_object_name="function", - outer_object_content="params", model_prefix="Function", fields_prefix="Parameters", list_of_outputs=True) - -system_message = "You are an advanced AI assistant. You are interacting with the user and with your environment by calling functions. You call functions by writing JSON objects, which represent specific function calls.\nBelow is a list of your available function calls:\n\n" + documentation +def main(): + parser = argparse.ArgumentParser(description=sys.modules[__name__].__doc__) + parser.add_argument("--host", default="localhost:8080", help="llama.cpp server") + parser.add_argument("-v", "--verbose", action="store_true", help="enables logging") + args = parser.parse_args() + logging.basicConfig(level=logging.INFO if args.verbose else logging.ERROR) + ret = 0 + # Comment out below to only run the example you want. + ret = ret or example_rce(args.host) + ret = ret or example_calculator(args.host) + ret = ret or example_struct(args.host) + ret = ret or example_concurrent(args.host) + return ret -text = """Get the date and time, get the current weather in celsius in London and solve the following calculation: 42 * 42""" -prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" - -text = create_completion(prompt=prompt, grammar=gbnf_grammar) - -json_data = json.loads(text) - -print(json_data) -# Should output something like this: -# [{'function': 'get_current_datetime', 'params': {'output_format': '%Y-%m-%d %H:%M:%S'}}, {'function': 'get_current_weather', 'params': {'location': 'London', 'unit': 'celsius'}}, {'function': 'Calculator', 'params': {'number_one': 42, 'operation': 'multiply', 'number_two': 42}}] - - -for call in json_data: - if call["function"] == "Calculator": - print(Calculator(**call["params"]).run()) - elif call["function"] == "get_current_datetime": - print(current_datetime_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue] - elif call["function"] == "get_current_weather": - print(current_weather_tool_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue] -# Should output something like this: -# 2024-01-14 13:36:06 -# {"location": "London", "temperature": "42", "unit": "celsius"} -# 1764 +if __name__ == "__main__": + sys.exit(main()) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3af2eec2f..84f6387e2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -59,6 +59,24 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q6_K: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ2_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ4_XS: mul_mat_q_case(ctx, args, stream); break; @@ -93,6 +111,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: mmq_supported = true; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 51c44d857..f08a4758d 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -63,6 +63,14 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_Q5_K: return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: return MMQ_Q8_1_DS_LAYOUT_D4; @@ -131,15 +139,16 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : @@ -152,42 +161,46 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : tile_x_sizes{0, 0, 0}; } -#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) -#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) -static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : 0; @@ -216,7 +229,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE); + float * x_df = (float *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; @@ -235,11 +248,13 @@ template static __device__ __forceinlin } const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } @@ -257,7 +272,7 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE @@ -304,95 +319,12 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx][4]; - float dA[ntx][mma_C::ne/2][4]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { - const int k0 = k00 + k01; - -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); - - A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { - mma_B B; - float dB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l]; - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; @@ -411,11 +343,13 @@ template static __device__ __forceinlin } const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } @@ -433,7 +367,7 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm; + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; #endif // INT8_MMA_AVAILABLE @@ -480,88 +414,6 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_A_I16K4 mma_A_K4; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx][4]; - half2 dmA[ntx][mma_C::ne/2][4]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { - const int k0 = k00 + k01; - - A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1); - A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F; - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -789,10 +641,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -808,6 +659,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const float * x_df = (const float *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; mma_A A[ntx][WARP_SIZE/QI8_0]; float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; @@ -840,18 +692,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { - const int k0 = k00 + k01; - - mma_B B; + mma_B B; float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } #pragma unroll @@ -866,10 +720,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE } template @@ -905,7 +755,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -922,8 +771,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_1]; - half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + mma_A A[ntx][WARP_SIZE/QI8_1]; + float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -944,7 +793,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]; + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); } } } @@ -953,18 +802,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - const int k0 = k00 + k01; + mma_B B; + float2 dsB[mma_C::ne/2]; - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); } #pragma unroll @@ -974,8 +821,120 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*WARP_SIZE + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + mma_B B[2]; + float dB[mma_C::ne/2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); } } } @@ -1222,7 +1181,6 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + 1); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1262,23 +1220,6 @@ template static __device__ __forceinlin } } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { - int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - -#ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d; -#else - x_df[i] = bxi->d; -#endif // INT8_MMA_AVAILABLE - } - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); @@ -1302,11 +1243,32 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); #ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc; + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + } #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; #endif // INT8_MMA_AVAILABLE } + +#ifndef INT8_MMA_AVAILABLE +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { + int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // INT8_MMA_AVAILABLE } template @@ -1342,99 +1304,14 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + 1; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][8]; - int scA[ntx][mma_C::ne/2][8]; - float dA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { - const int k0 = k00 + k01; - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16]; - const int8_t * sc = (const int8_t *) &sc_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][k01/4 + ksc] = sc[ksc]; - } - } - - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - mma_B B[2]; - float dB[mma_C::ne/2]; - - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]* - (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]); - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits } template static __device__ __forceinline__ void load_tiles_q4_K( @@ -1442,8 +1319,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); - int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1451,9 +1327,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_dm + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI4_K - const int kqsx = threadIdx.x; // threadIdx.x % QI4_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1462,33 +1335,59 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, threadIdx.x); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } -#ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { + int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; } #pragma unroll @@ -1504,17 +1403,11 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - -#ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8; -#else - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; -#endif // INT8_MMA_AVAILABLE + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template @@ -1544,133 +1437,18 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template -static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; - const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; - int mA[ntx][mma_C::ne/2][4]; - half2 dmA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { - const int k0 = k00 + k01; - - A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K); - -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F; - A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F; - } - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)]; - const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)]; - - const uint8_t * sc = (const uint8_t *) &sc_packed; - const uint8_t * m = (const uint8_t *) &m_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][ksc] = sc[ksc]; - mA[n][l][ksc] = m[ksc]; - } - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float tmpd[ntx][mma_C::ne] = {{0.0f}}; - float tmpm[ntx][mma_C::ne] = {{0.0f}}; - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/8], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); - } - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1678,9 +1456,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_dm + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI5_K - const int kqsx = threadIdx.x; // threadIdx.x % QI5_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1689,73 +1464,91 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx; - const int ky = QR5_K*kqsx; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*threadIdx.x; - const int ql = get_int_b4(bxi->qs, kqsx); + const int ql = get_int_b4(bxi->qs, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); + const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } -#ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { + int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - -#ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8; -#else - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; -#endif // INT8_MMA_AVAILABLE + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template @@ -1785,122 +1578,12 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template -static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; - int mA[ntx][mma_C::ne/2][4]; - half2 dmA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)]; - const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)]; - - const uint8_t * sc = (const uint8_t *) &sc_packed; - const uint8_t * m = (const uint8_t *) &m_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][ksc] = sc[ksc]; - mA[n][l][ksc] = m[ksc]; - } - } - - #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float tmpd[ntx][mma_C::ne] = {{0.0f}}; - float tmpm[ntx][mma_C::ne] = {{0.0f}}; - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/8], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); - } - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -1915,9 +1598,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_df + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI6_K - const int kqsx = threadIdx.x; // threadIdx.x % QI6_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1926,19 +1606,18 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx; - const int ky = QR6_K*kqsx; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_b2(bxi->ql, kqsx); + const int ql = get_int_b2(bxi->ql, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); + const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; - const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); + const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; + const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); @@ -2187,6 +1866,358 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % QI1_S; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + +#ifdef INT8_MMA_AVAILABLE + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2320,7 +2351,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; @@ -2328,7 +2359,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; @@ -2336,7 +2367,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2352,7 +2383,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2368,7 +2399,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; @@ -2376,7 +2407,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; @@ -2384,7 +2415,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; @@ -2396,11 +2427,59 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2408,7 +2487,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2837,6 +2916,12 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index ffeb3c27d..d7874e6ea 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -23,7 +23,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block} TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", - "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu new file mode 100644 index 000000000..84ec85029 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu new file mode 100644 index 000000000..583c4e5a5 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu new file mode 100644 index 000000000..edaf1560d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu new file mode 100644 index 000000000..233d9342c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu new file mode 100644 index 000000000..6092dc713 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu new file mode 100644 index 000000000..1d5bd201f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6a17d0f3e..40091a0ef 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -188,6 +188,27 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp return sumi*d8d8 + m8s8 / (QI8_1 / vdr); } +template static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl( + const int * v, const int * u, const float * d8_0, const float & d8_1) { + + float sumf = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) { + int sumi = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_0/2; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + sumf += d8_0[i0/(QI8_0/2)]*sumi; + } + + return d8_1*sumf; +} + #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7a39c685b..dbb3a3ebe 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21015,7 +21015,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p gguf_tensor_info_sanitize(info); // make sure there is no duplicated tensor names - for (uint64_t j = 0; j < i; ++j) { + for (uint64_t j = 0; j < i && ok; ++j) { if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); ok = false; diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index bac6ebfb3..15189f717 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -54,6 +54,7 @@ class Metadata: model_card = Metadata.load_model_card(model_path) hf_params = Metadata.load_hf_parameters(model_path) + # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) @@ -177,6 +178,12 @@ class Metadata: org_component = None name_parts: list[str] = model_full_name_component.split('-') + + # Remove empty parts + for i in reversed(range(len(name_parts))): + if len(name_parts[i]) == 0: + del name_parts[i] + name_types: list[ set[Literal["basename", "size_label", "finetune", "version", "type"]] ] = [set() for _ in name_parts] @@ -223,9 +230,19 @@ class Metadata: name_parts[i] = part # Some easy to recognize finetune names elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE): - name_types[i].add("finetune") - if part.lower() == "lora": - name_parts[i] = "LoRA" + if total_params < 0 and part.lower() == "lora": + # ignore redundant "lora" in the finetune part when the output is a lora adapter + name_types[i].add("type") + else: + name_types[i].add("finetune") + + # Ignore word-based size labels when there is at least a number-based one present + # TODO: should word-based size labels always be removed instead? + if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n): + for n, t in zip(name_parts, name_types): + if "size_label" in t: + if all(c.isalpha() for c in n): + t.remove("size_label") at_start = True # Find the basename through the annotated name @@ -240,18 +257,18 @@ class Metadata: # Remove the basename annotation from trailing version for part, t in zip(reversed(name_parts), reversed(name_types)): - if "basename" in t: - if len(t) > 1: - t.remove("basename") + if "basename" in t and len(t) > 1: + t.remove("basename") else: break basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None - size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None + # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys) + size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None # TODO: should the basename version always be excluded? - # TODO: should multiple versions be joined together? - version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1] + # NOTE: multiple finetune versions are joined together + version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None if size_label is None and finetune is None and version is None: # Too ambiguous, output nothing diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index ef76831b5..40d59b75e 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -50,15 +50,15 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention if base_name is not None: - name = base_name.strip().title().replace(' ', '-').replace('/', '-') + name = base_name.strip().replace(' ', '-').replace('/', '-') elif model_name is not None: - name = model_name.strip().title().replace(' ', '-').replace('/', '-') + name = model_name.strip().replace(' ', '-').replace('/', '-') else: name = "ggml-model" parameters = f"-{size_label}" if size_label is not None else "" - finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else "" + finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else "" version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else "" diff --git a/gguf-py/scripts/gguf_dump.py b/gguf-py/scripts/gguf_dump.py index a73ca2776..1b6546541 100755 --- a/gguf-py/scripts/gguf_dump.py +++ b/gguf-py/scripts/gguf_dump.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import argparse import os +import re import sys from pathlib import Path from typing import Any @@ -244,26 +245,58 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None else: pretty_type = str(field.types[-1].name) + def escape_markdown_inline_code(value_string): + # Find the longest contiguous sequence of backticks in the string then + # wrap string with appropriate number of backticks required to escape it + max_backticks = max((len(match.group(0)) for match in re.finditer(r'`+', value_string)), default=0) + inline_code_marker = '`' * (max_backticks + 1) + + # If the string starts or ends with a backtick, add a space at the beginning and end + if value_string.startswith('`') or value_string.endswith('`'): + value_string = f" {value_string} " + + return f"{inline_code_marker}{value_string}{inline_code_marker}" + total_elements = len(field.data) value = "" if len(field.types) == 1: curr_type = field.types[0] if curr_type == GGUFValueType.STRING: - value = repr(str(bytes(field.parts[-1]), encoding='utf-8')[:60]) + truncate_length = 60 + value_string = str(bytes(field.parts[-1]), encoding='utf-8') + if len(value_string) > truncate_length: + head = escape_markdown_inline_code(value_string[:truncate_length // 2]) + tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + value = "{head}...{tail}".format(head=head, tail=tail) + else: + value = escape_markdown_inline_code(value_string) elif curr_type in reader.gguf_scalar_to_np: value = str(field.parts[-1][0]) else: if field.types[0] == GGUFValueType.ARRAY: curr_type = field.types[1] + array_elements = [] + if curr_type == GGUFValueType.STRING: render_element = min(5, total_elements) for element_pos in range(render_element): - value += repr(str(bytes(field.parts[-1 - element_pos]), encoding='utf-8')[:5]) + (", " if total_elements > 1 else "") + truncate_length = 30 + value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + if len(value_string) > truncate_length: + head = escape_markdown_inline_code(value_string[:truncate_length // 2]) + tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + value = "{head}...{tail}".format(head=head, tail=tail) + else: + value = escape_markdown_inline_code(value_string) + array_elements.append(value) + elif curr_type in reader.gguf_scalar_to_np: render_element = min(7, total_elements) for element_pos in range(render_element): - value += str(field.parts[-1 - element_pos][0]) + (", " if total_elements > 1 else "") - value = f'[ {value}{" ..." if total_elements > 1 else ""} ]' + array_elements.append(str(field.parts[-1 - (total_elements - element_pos - 1)][0])) + + value = f'[ {", ".join(array_elements).strip()}{", ..." if total_elements > len(array_elements) else ""} ]' + kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value}) kv_dump_table_header_map = [ @@ -382,7 +415,7 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n" markdown_content += "\n\n" - print(markdown_content) # noqa: NP100 + print(markdown_content) # noqa: NP100 def main() -> None: diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py index 3fac82188..81a2a30ae 100755 --- a/gguf-py/tests/test_metadata.py +++ b/gguf-py/tests/test_metadata.py @@ -54,7 +54,7 @@ class TestMetadataMethod(unittest.TestCase): self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"), ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B')) - # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing... + # Non standard naming self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"), ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B')) @@ -71,7 +71,7 @@ class TestMetadataMethod(unittest.TestCase): self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3), ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K')) - # None standard and not easy to disambiguate + # Non standard and not easy to disambiguate self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"), ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None)) @@ -123,6 +123,51 @@ class TestMetadataMethod(unittest.TestCase): self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"), ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B')) + # Ignore full-text size labels when there are number-based ones, and deduplicate size labels + self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"), + ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B')) + + # Instruct in a name without a size label + self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"), + ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None)) + + # Non-obvious splitting relying on 'chat' keyword + self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"), + ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None)) + + # Multiple versions + self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"), + ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B')) + + # TODO: DPO in the name + self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"), + ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B')) + + # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename + self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"), + ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B')) + + # Too ambiguous + # TODO: should "base" be a 'finetune' or 'size_label'? + # (in this case it should be a size label, but other models use it to signal that they are not finetuned) + self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"), + ('Florence-2-base', 'microsoft', None, None, None, None)) + + ## Invalid cases ## + + # Start with a dash and has dashes in rows + self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"), + ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None)) + + ## LoRA ## + + self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"), + ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B')) + + # Negative size --> output is a LoRA adaper --> prune "LoRA" out of the name to avoid redundancy with the suffix + self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234), + ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B')) + def test_apply_metadata_heuristic_from_model_card(self): model_card = { 'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], @@ -134,7 +179,7 @@ class TestMetadataMethod(unittest.TestCase): } got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) expect = gguf.Metadata() - expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}] + expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}] expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'] expect.languages=['en'] expect.datasets=['teknium/OpenHermes-2.5'] diff --git a/include/llama.h b/include/llama.h index 8fd4e0bb9..2dc4e6d30 100644 --- a/include/llama.h +++ b/include/llama.h @@ -93,6 +93,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_VIKING = 18, LLAMA_VOCAB_PRE_TYPE_JAIS = 19, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 20, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 21, }; // note: these values should be synchronized with ggml_rope diff --git a/src/llama.cpp b/src/llama.cpp index 372decad4..34b4191d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5527,6 +5527,12 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "tekken") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -15589,6 +15595,13 @@ struct llm_tokenizer_bpe { "\\p{N}", }; break; + case LLAMA_VOCAB_PRE_TYPE_TEKKEN: + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = {