diff --git a/common/chat-template.hpp b/common/chat-template.hpp
index 0e88fb361..882ba41bd 100644
--- a/common/chat-template.hpp
+++ b/common/chat-template.hpp
@@ -249,16 +249,30 @@ class chat_template {
                 inputs.add_generation_prompt = false;
                 full = apply(inputs);
             }
-
-            if (full.find(prefix) != 0) {
-                if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
-                    prefix = prefix.substr(0, prefix.size() - eos_token_.size());
+            auto eos_pos_last = full.rfind(eos_token_);
+            if (eos_pos_last == prefix.size() - eos_token_.size() ||
+                (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
+                full = full.substr(0, eos_pos_last);
+            }
+            size_t common_prefix_length = 0;
+            for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+                if (prefix[i] != full[i]) {
+                    break;
                 }
+                if (prefix[i] == '<') {
+                    // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                    // but it removes thinking tags for past messages.
+                    // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+                    continue;
+                }
+                common_prefix_length = i + 1;
             }
-            if (full.find(prefix) != 0) {
+            auto example = full.substr(common_prefix_length);
+            if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
                 fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+            } else {
+                tool_call_example_ = example;
             }
-            tool_call_example_ = full.substr(prefix.size());
         }
     } catch (const std::exception & e) {
         fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
         if (polyfill_tools) {
             adjusted_messages = add_system(inputs.messages,
                 "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
-                (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+                (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
         } else {
             adjusted_messages = inputs.messages;
         }
diff --git a/common/common.h b/common/common.h
index 4eac8db5a..b208d0c7e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -383,7 +383,7 @@ struct common_params {
     int32_t i_pos = -1; // position of the passkey in the junk text
 
     // imatrix params
-    std::string out_file = "imatrix.gguf"; // save the resulting imatrix to this file
+    std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
 
     int32_t n_out_freq  = 10; // output the imatrix every n_out_freq iterations
     int32_t n_save_freq =  0; // save the imatrix every n_save_freq iterations
diff --git a/common/minja.hpp b/common/minja.hpp
index c304b5c66..c58dd66e0 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
   return s.substr(start, end - start + 1);
 }
 
+static std::string capitalize(const std::string & s) {
+  if (s.empty()) return s;
+  auto result = s;
+  result[0] = std::toupper(result[0]);
+  return result;
+}
+
 static std::string html_escape(const std::string & s) {
   std::string result;
   result.reserve(s.size());
@@ -1462,6 +1469,9 @@ public:
         if (method->get_name() == "strip") {
           vargs.expectArgs("strip method", {0, 0}, {0, 0});
           return Value(strip(str));
+        } else if (method->get_name() == "capitalize") {
+          vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
+          return Value(capitalize(str));
         } else if (method->get_name() == "endswith") {
           vargs.expectArgs("endswith method", {1, 1}, {0, 0});
           auto suffix = vargs.args[0].get<std::string>();
@@ -1792,7 +1802,7 @@ private:
     auto left = parseStringConcat();
     if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
 
-    static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
+    static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
     static std::regex not_tok(R"(not\b)");
     std::string op_str;
     while (!(op_str = consumeToken(compare_tok)).empty()) {
@@ -2171,7 +2181,7 @@ private:
   using TemplateTokenIterator = TemplateTokenVector::const_iterator;
 
   std::vector<std::string> parseVarNames() {
-    static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
+    static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
 
    std::vector<std::string> group;
    if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
@@ -2194,13 +2204,13 @@ private:
  }
 
  TemplateTokenVector tokenize() {
-    static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
+    static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
     static std::regex expr_open_regex(R"(\{\{([-~])?)");
-    static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
+    static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
     static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
     static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
-    static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
-    static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
+    static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
+    static std::regex block_close_regex(R"(\s*([-~])?%\})");
 
     TemplateTokenVector tokens;
     std::vector<std::string> group;
@@ -2284,7 +2294,7 @@ private:
         auto post_space = parseBlockClose();
         tokens.push_back(std::make_unique(location, pre_space, post_space));
       } else if (keyword == "set") {
-
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))"); std::string ns; std::vector var_names; @@ -2336,6 +2346,11 @@ private: throw std::runtime_error("Unexpected block: " + keyword); } } else if (std::regex_search(it, end, match, non_text_open_regex)) { + if (!match.position()) { + if (match[0] != "{#") + throw std::runtime_error("Internal error: Expected a comment"); + throw std::runtime_error("Missing end of comment tag"); + } auto text_end = it + match.position(); text = std::string(it, text_end); it = text_end; @@ -2400,7 +2415,7 @@ private: auto text = text_token->text; if (post_space == SpaceHandling::Strip) { - static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + static std::regex trailing_space_regex(R"(\s+$)"); text = std::regex_replace(text, trailing_space_regex, ""); } else if (options.lstrip_blocks && it != end) { auto i = text.size(); @@ -2410,7 +2425,7 @@ private: } } if (pre_space == SpaceHandling::Strip) { - static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + static std::regex leading_space_regex(R"(^\s+)"); text = std::regex_replace(text, leading_space_regex, ""); } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { if (text.length() > 0 && text[0] == '\n') { diff --git a/convert_legacy_imatrix_to_gguf.py b/convert_legacy_imatrix_to_gguf.py deleted file mode 100644 index bd72655bf..000000000 --- a/convert_legacy_imatrix_to_gguf.py +++ /dev/null @@ -1,122 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import os -import sys -import logging -import argparse - -from typing import Any -from pathlib import Path -from dataclasses import dataclass - -import numpy as np -import numpy.typing as npt - -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) -import gguf - - -logger = logging.getLogger("imatrix-to-gguf") - - -class IMatrixWriter(gguf.GGUFWriter): - def add_architecture(self) -> None: - # no arch is stored in imatrix files - pass - - -@dataclass -class IMatrixEntry: - values: np.ndarray[Any, np.dtype[np.float32]] - counts: np.ndarray[Any, np.dtype[np.float32]] - - -class IMatrixReader: - chunk_size: int = 512 # guess - offset: int = 0 - data: np.ndarray[Any, np.dtype[np.uint8]] - n_enties: int - entries: dict[str, IMatrixEntry] - chunk_count: int - dataset: str - - def _get(self, dtype: npt.DTypeLike, count: int = 1) -> npt.NDArray[Any]: - count = int(count) - itemsize = int(np.empty([], dtype=dtype).itemsize) - offset = self.offset - self.offset = offset + itemsize * count - return self.data[offset:self.offset].view(dtype=dtype)[:count] - - def __init__(self, imatrix: Path): - self.offset = 0 - self.entries = {} - self.data = np.memmap(imatrix) - n_entries = self._get(np.int32).item() - assert n_entries >= 0 - for _ in range(n_entries): - len = self._get(np.int32).item() - name = self._get(np.uint8, len).tobytes().decode("utf-8") - ncall = self._get(np.int32).item() - nval = self._get(np.int32).item() - data = self._get(np.float32, nval) - assert name not in self.entries, f"duplicated name: {name!r}" - - self.entries[name] = IMatrixEntry(data * np.float32(self.chunk_size), np.array([ncall * self.chunk_size], dtype=np.float32)) - - self.chunk_count = self._get(np.int32).item() - dataset_len = self._get(np.int32).item() - self.dataset = self._get(np.uint8, dataset_len).tobytes().decode("utf-8") - - def to_writer(self, outfile: Path) -> IMatrixWriter: - writer = 
IMatrixWriter(path=outfile, arch="") - - writer.add_type(gguf.GGUFType.IMATRIX) - writer.add_key_value(gguf.Keys.IMatrix.CHUNK_COUNT, self.chunk_count, gguf.GGUFValueType.UINT32) - writer.add_key_value(gguf.Keys.IMatrix.CHUNK_SIZE, self.chunk_size, gguf.GGUFValueType.UINT32) - writer.add_key_value(gguf.Keys.IMatrix.DATASET, self.dataset, gguf.GGUFValueType.STRING) - - for name, entry in self.entries.items(): - writer.add_tensor(name + ".sums", entry.values) - writer.add_tensor(name + ".counts", entry.counts) - - return writer - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Convert an old imatrix.dat file to a GGUF compatible file") - parser.add_argument( - "--outfile", type=Path, - help="path to write to; default: based on input.", - ) - parser.add_argument( - "--verbose", action="store_true", - help="increase output verbosity", - ) - parser.add_argument( - "imatrix", type=Path, - help="path to an imatrix file", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - - if args.outfile is None: - input_file: Path = args.imatrix - if input_file.suffix != ".gguf": - args.outfile = input_file.with_suffix(".gguf") - if args.outfile.exists(): - logger.error(f"default file exists, specify with --outfile to overwrite: {args.outfile}") - exit(1) - - writer = IMatrixReader(args.imatrix).to_writer(args.outfile) - - writer.write_header_to_file(args.outfile) - writer.write_kv_data_to_file() - writer.write_tensors_to_file() diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 99056e74c..b5f3feb9f 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -2,7 +2,6 @@ #include "common.h" #include "log.h" #include "llama.h" -#include "gguf.h" #include #include @@ -11,8 +10,8 @@ #include #include #include +#include #include -#include #include #if defined(_MSC_VER) @@ -22,27 +21,16 @@ static void print_usage(int, char ** argv) { LOG("\nexample usage:\n"); LOG("\n %s \\\n" - " -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \\\n" + " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n" " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n" - " [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...]\n" , argv[0]); + " [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]); LOG("\n"); } -static bool str_remove_suffix(std::string & str, const std::string & suffix) { - bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - -static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; -static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; -static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; - struct Stats { - std::vector values; - std::vector counts; + std::vector values; + std::vector counts; + int ncall = 0; }; class IMatrixCollector { @@ -50,13 +38,13 @@ public: IMatrixCollector() = default; void set_params(common_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); - void save_imatrix(int32_t n_chunk = -1) const; - bool load_imatrix(const char * file_name); + void save_imatrix(int ncall = -1) const; + bool load_imatrix(const char * fname); private: 
std::unordered_map m_stats; common_params m_params; std::mutex m_mutex; - int32_t m_last_chunk = 0; + int m_last_call = 0; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id }; @@ -130,23 +118,17 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; - if (e.counts.size() == 1 && n_as > 1) { - // broadcast, when loading an old imatrix - e.counts.resize(n_as, e.counts[0]); - } + ++e.ncall; + if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); - e.counts.resize(n_as, 0); + e.counts.resize(src1->ne[0]*n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); exit(1); //GGML_ABORT("fatal error"); } - else if (e.counts.size() != (size_t)n_as) { - LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as); - exit(1); //GGML_ABORT("fatal error"); - } - LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); + LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); // loop over all possible experts, regardless if they are used or not in the batch for (int ex = 0; ex < n_as; ++ex) { size_t e_start = ex*src1->ne[0]; @@ -163,26 +145,24 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const int64_t i12 = row; const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]); - e.counts[ex]++; - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[e_start + j] = std::fma(x[j], x[j], e.values[e_start + j]); - if (!std::isfinite((float)e.values[e_start + j])) { - LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); + e.values[e_start + j] += x[j]*x[j]; + e.counts[e_start + j]++; + if (!std::isfinite(e.values[e_start + j])) { + LOG("\n"); + LOG_ERR("%f detected in %s\n", e.values[e_start + j], wname.c_str()); exit(1); } } } } - const int32_t n_chunk = e.counts[ex] / (m_params.n_ctx / m_params.n_parallel); - if (n_chunk > m_last_chunk) { - const int32_t chunk_step = n_chunk - m_last_chunk; - m_last_chunk = n_chunk; - if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { + if (e.ncall > m_last_call) { + m_last_call = e.ncall; + if (m_last_call % m_params.n_out_freq == 0) { save_imatrix(); } - if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { - save_imatrix(m_last_chunk); + if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { + save_imatrix(m_last_call); } } } @@ -190,38 +170,32 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * auto & e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); - e.counts.resize(1, 0); + e.counts.resize(src1->ne[0], 0); } else if (e.values.size() != (size_t)src1->ne[0]) { LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); exit(1); //GGML_ABORT("fatal error"); } - else if (e.counts.size() != 1) { - LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1); - exit(1); //GGML_ABORT("fatal error"); - } - LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, 
wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - // TODO: higher dimensions + ++e.ncall; + LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); for (int row = 0; row < (int)src1->ne[1]; ++row) { const float * x = data + row * src1->ne[0]; - e.counts[0]++; for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] = std::fma(x[j], x[j], e.values[j]); - if (!std::isfinite((float)e.values[j])) { - LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); + e.values[j] += x[j]*x[j]; + e.counts[j]++; + if (!std::isfinite(e.values[j])) { + LOG_ERR("%f detected in %s\n", e.values[j], wname.c_str()); exit(1); } } } - const int32_t n_chunk = e.counts[0] / (m_params.n_ctx / m_params.n_parallel); - if (n_chunk > m_last_chunk) { - const int32_t chunk_step = n_chunk - m_last_chunk; - m_last_chunk = n_chunk; - if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) { + if (e.ncall > m_last_call) { + m_last_call = e.ncall; + if (m_last_call % m_params.n_out_freq == 0) { save_imatrix(); } - if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) { - save_imatrix(m_last_chunk); + if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) { + save_imatrix(m_last_call); } } } @@ -229,22 +203,22 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * return true; } -void IMatrixCollector::save_imatrix(int32_t n_chunk) const { +void IMatrixCollector::save_imatrix(int ncall) const { auto fname = m_params.out_file; if (fname.empty()) { - fname = "imatrix.gguf"; + fname = "imatrix.dat"; } - if (n_chunk > 0) { + if (ncall > 0) { fname += ".at_"; - fname += std::to_string(n_chunk); + fname += std::to_string(ncall); } // avoid writing imatrix entries that do not have full data // this can happen with MoE models where some of the experts end up not being exercised by the provided training data + int n_entries = 0; std::vector to_store; - size_t data_size = 0; bool is_first = true; // for printing for (const auto & kv : m_stats) { @@ -276,157 +250,101 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { continue; } + n_entries++; to_store.push_back(kv.first); - data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.values.size(), GGML_MEM_ALIGN); - data_size += GGML_PAD(ggml_tensor_overhead() + sizeof(float) * kv.second.counts.size(), GGML_MEM_ALIGN); } if (to_store.size() < m_stats.size()) { LOG_WRN("%s: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); } - // deterministic tensor name order - std::sort(to_store.begin(), to_store.end()); - - struct ggml_init_params params = { - /* .mem_size = */ data_size, - /* .mem_buffer = */ NULL, - /* .no_alloc = */ false, - }; - struct ggml_context * ctx = ggml_init(params); - struct gguf_context * ctx_gguf = gguf_init_empty(); - - gguf_set_val_str(ctx_gguf, "general.type", "imatrix"); - // Write the input filename to later on specify it in quantize - gguf_set_val_str(ctx_gguf, LLM_KV_IMATRIX_DATASET, m_params.prompt_file.c_str()); - // Write the number of chunks the matrix was computed with - gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT, m_last_chunk); - gguf_set_val_u32(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE, m_params.n_ctx / m_params.n_parallel); - + std::ofstream out(fname, std::ios::binary); + out.write((const char *) &n_entries, sizeof(n_entries)); for (const auto & name : 
to_store) { const auto & stat = m_stats.at(name); - const int32_t nval = (int32_t) stat.values.size(); - const int32_t nmat = (int32_t) stat.counts.size(); + int len = name.size(); + out.write((const char *) &len, sizeof(len)); + out.write(name.c_str(), len); + out.write((const char *) &stat.ncall, sizeof(stat.ncall)); + int nval = stat.values.size(); + out.write((const char *) &nval, sizeof(nval)); if (nval > 0) { - struct ggml_tensor * sums = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nval / nmat, nmat); - struct ggml_tensor * counts = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, nmat); - ggml_format_name(sums, "%s.sums", name.c_str()); - ggml_format_name(counts, "%s.counts", name.c_str()); - - for (int32_t j = 0; j < nval; ++j) { - ((float *) sums->data)[j] = (float) stat.values[j]; + std::vector tmp(nval); + for (int i = 0; i < nval; i++) { + tmp[i] = (stat.values[i] / static_cast(stat.counts[i])) * static_cast(stat.ncall); } - for (int32_t j = 0; j < nmat; ++j) { - ((float *) counts->data)[j] = (float) stat.counts[j]; - } - - gguf_add_tensor(ctx_gguf, sums); - gguf_add_tensor(ctx_gguf, counts); + out.write((const char*)tmp.data(), nval*sizeof(float)); } } - gguf_write_to_file(ctx_gguf, fname.c_str(), false); + // Write the number of call the matrix was computed with + out.write((const char *) &m_last_call, sizeof(m_last_call)); + + // Write the input filename at the end of the file to later on specify it in quantize + { + int len = m_params.prompt_file.size(); + out.write((const char *) &len, sizeof(len)); + out.write(m_params.prompt_file.c_str(), len); + } LOGV(1, "\n"); - LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str()); - - gguf_free(ctx_gguf); - ggml_free(ctx); + LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str()); } -bool IMatrixCollector::load_imatrix(const char * file_name) { - struct ggml_context * ctx = nullptr; - struct gguf_init_params meta_gguf_params = { - /* .no_alloc = */ false, // the data is needed - /* .ctx = */ &ctx, - }; - struct gguf_context * ctx_gguf = gguf_init_from_file(file_name, meta_gguf_params); - if (!ctx_gguf) { +bool IMatrixCollector::load_imatrix(const char * fname) { + std::ifstream in(fname, std::ios::binary); + if (!in) { + LOG_ERR("%s: failed to open %s\n",__func__, fname); return false; } - const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); - if (n_entries < 1) { - LOG_ERR("%s: no data in file %s\n", __func__, file_name); - gguf_free(ctx_gguf); - ggml_free(ctx); + int n_entries; + in.read((char*)&n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + LOG_ERR("%s: no data in file %s\n", __func__, fname); return false; } - - const std::string sums_suffix{".sums"}; - const std::string counts_suffix{".counts"}; - - // Could re-use m_stats instead, but this allows - // checking for completeness of *each* loaded imatrix file - // and also makes it easier to re-use a similar implementation in quantize.cpp - // Using an ordered map to get a deterministic iteration order. 
- std::map> sums_counts_for; - - for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string name = cur->name; - - if (name.empty()) { continue; } - - if (str_remove_suffix(name, sums_suffix)) { - // sums - sums_counts_for[name].first = cur; - } else if (str_remove_suffix(name, counts_suffix)) { - // counts - sums_counts_for[name].second = cur; - } else { - LOG_ERR("%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); + for (int i = 0; i < n_entries; ++i) { + int len; in.read((char *)&len, sizeof(len)); + std::vector name_as_vec(len+1); + in.read((char *)name_as_vec.data(), len); + if (in.fail()) { + LOG_ERR("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname); return false; } - } - - for (const auto & sc : sums_counts_for) { - const std::string & name = sc.first; - const struct ggml_tensor * sums = sc.second.first; - const struct ggml_tensor * counts = sc.second.second; - - if (!sums || !counts) { - LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); + name_as_vec[len] = 0; + std::string name{name_as_vec.data()}; + auto & e = m_stats[std::move(name)]; + int ncall; + in.read((char*)&ncall, sizeof(ncall)); + int nval; + in.read((char *)&nval, sizeof(nval)); + if (in.fail() || nval < 1) { + LOG_ERR("%s: failed reading number of values for entry %d\n",__func__,i); + m_stats = {}; return false; } - auto & e = m_stats[name]; - - int64_t nval = ggml_nelements(sums); if (e.values.empty()) { e.values.resize(nval, 0); - } else if ((size_t) nval != e.values.size()) { - LOG_ERR("%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); - gguf_free(ctx_gguf); - ggml_free(ctx); + e.counts.resize(nval, 0); + } + + std::vector tmp(nval); + in.read((char*)tmp.data(), nval*sizeof(float)); + if (in.fail()) { + LOG_ERR("%s: failed reading data for entry %d\n",__func__,i); + m_stats = {}; return false; } - int64_t ncounts = ggml_nelements(counts); - if (e.counts.empty()) { - e.counts.resize(ncounts, 0); - } else if (e.counts.size() == 1 && ncounts > 1) { - // broadcast, when loading an old imatrix - e.counts.resize(ncounts, e.counts[0]); - } else if ((size_t) ncounts != e.counts.size()) { - LOG_ERR("%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size()); - gguf_free(ctx_gguf); - ggml_free(ctx); - return false; + // Recreate the state as expected by save_imatrix(), and corerct for weighted sum. 
+ for (int i = 0; i < nval; i++) { + e.values[i] += tmp[i]; + e.counts[i] += ncall; } + e.ncall += ncall; - // Recreate the state as expected by save_imatrix() - for (int64_t j = 0; j < nval; j++) { - e.values[j] += ((const float *) sums->data)[j]; - } - for (int64_t j = 0; j < ncounts; j++) { - e.counts[j] += std::lround(((const float *) counts->data)[j]); - } } - gguf_free(ctx_gguf); - ggml_free(ctx); return true; } @@ -509,11 +427,12 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) { +static bool compute_imatrix(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const bool add_bos = llama_vocab_get_add_bos(vocab); + const int n_ctx = llama_n_ctx(ctx); GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); @@ -558,61 +477,45 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c double nll = 0.0; double nll2 = 0.0; + LOG_INF("%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch); + + std::vector workers(std::thread::hardware_concurrency() - 1); + const int num_batches = (n_ctx + n_batch - 1) / n_batch; - const int n_seq = std::max(1, n_batch / n_ctx); - - GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); - GGML_ASSERT(params.n_ctx == n_seq * n_ctx); - - llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); std::vector logits; if (params.compute_ppl && num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); } - LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); - - std::vector workers(std::thread::hardware_concurrency() - 1); - - for (int i = 0; i < n_chunk; i += n_seq) { + for (int i = 0; i < n_chunk; ++i) { const int start = i * n_ctx; const int end = start + n_ctx; - const int n_seq_batch = std::min(n_seq, n_chunk - i); + std::vector logits; const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - // clear the batch + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[batch_start] = llama_vocab_bos(vocab); + } + common_batch_clear(batch); - - for (int seq = 0; seq < n_seq_batch; seq++) { - int seq_start = batch_start + seq*n_ctx; - - // save original token and restore it after eval - const auto token_org = tokens[seq_start]; - - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[seq_start] = llama_vocab_bos(vocab); - } - for (int k = 0; k < batch_size; ++k) { - // NOTE: specifying all logits to get activations for the output.weight tensor - // and also for the perplexity calculation. 
- // TODO: only get outputs when (params.process_output || params.compute_ppl) - // (not possible when this skips FFN computation of the last layer) - common_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true); - } - - // restore the original token in case it was set to BOS - tokens[seq_start] = token_org; + for (int i = 0; i < batch_size; i++) { + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); } if (llama_decode(ctx, batch)) { @@ -621,19 +524,23 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c return false; } + // restore the original token in case it was set to BOS + tokens[batch_start] = token_org; + if (params.compute_ppl && num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } } + llama_batch_free(batch); + + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { - llama_synchronize(ctx); - const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total * n_chunk / n_seq); + int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); @@ -643,27 +550,17 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c if (params.compute_ppl) { const int first = n_ctx/2; - for (int seq = 0; seq < n_seq_batch; seq++) { - const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx); + const auto * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); + count += n_ctx - first - 1; - llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; - - process_logits(n_vocab, all_logits + first*n_vocab, - tokens_data, n_ctx - 1 - first, - workers, nll, nll2, - logit_history.data() + start + seq*n_ctx + first, - prob_history.data() + start + seq*n_ctx + first); - - count += n_ctx - first - 1; - - LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); - } + LOG("[%d]%.4lf,", i + 1, std::exp(nll / count)); fflush(stdout); logits.clear(); } } - LOG("\n"); if (params.compute_ppl) { @@ -679,8 +576,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c } } - llama_batch_free(batch); - return true; } @@ -697,22 +592,7 @@ int main(int argc, char ** argv) { common_init(); - const int32_t n_ctx = params.n_ctx; - - if (n_ctx <= 0) { - LOG_ERR("%s: imatrix tool requires '--ctx-size' > 0\n", __func__); - return 1; - } - - { - const int32_t n_seq = std::max(1, params.n_batch / n_ctx); - const int32_t n_kv = n_seq * n_ctx; - - params.n_parallel = n_seq; - params.n_ctx = n_kv; - - params.n_batch = std::min(params.n_batch, n_kv); - } + params.n_batch = std::min(params.n_batch, params.n_ctx); g_collector.set_params(params); @@ -768,7 +648,7 @@ int main(int argc, char ** argv) { } LOG_INF("No prompt provided; combining precomputed matrices only.\n"); } else { - if (!compute_imatrix(ctx, params, n_ctx)) { + if (!compute_imatrix(ctx, params)) { return 1; } } diff --git a/examples/main/README.md b/examples/main/README.md index 46f92eb7a..ceaed42f6 
100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -37,7 +37,7 @@ Once downloaded, place your model in the models folder in llama.cpp. ##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it): ```bash -./llama-cli -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 +./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 ``` ### Windows: diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 355aef4a6..8d47b17b6 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,13 +1,13 @@ #include "common.h" #include "llama.h" -#include "gguf.h" #include #include #include #include #include -#include +#include +#include struct quant_option { std::string name; @@ -60,11 +60,6 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count"; static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count"; -// TODO: share with imatrix.cpp -static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; -static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; -static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; - static bool striequals(const char * a, const char * b) { while (*a && *b) { if (std::tolower(*a) != std::tolower(*b)) { @@ -134,114 +129,67 @@ static void usage(const char * executable) { exit(1); } -// TODO: share with implementation in imatrix.cpp -static bool str_remove_suffix(std::string & str, const std::string & suffix) { - bool has_suffix = str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), str.size(), suffix) == 0; - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map> & imatrix_data) { - - struct ggml_context * ctx = nullptr; - struct gguf_init_params meta_gguf_params = { - /* .no_alloc = */ false, // the data is needed - /* .ctx = */ &ctx, - }; - struct gguf_context * ctx_gguf = gguf_init_from_file(imatrix_file.c_str(), meta_gguf_params); - if (!ctx_gguf) { - fprintf(stderr, "%s: if this is an older imatrix file, make sure to convert it to the GGUF-based imatrix format\n", __func__); + std::ifstream in(imatrix_file.c_str(), std::ios::binary); + if (!in) { + printf("%s: failed to open %s\n",__func__, imatrix_file.c_str()); exit(1); } - const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); - if (n_entries < 1) { - fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); + int n_entries; + in.read((char *)&n_entries, sizeof(n_entries)); + if (in.fail() || n_entries < 1) { + printf("%s: no data in file %s\n", __func__, imatrix_file.c_str()); exit(1); } - - const int dataset_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_DATASET); - const int chunk_count_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_COUNT); - const int chunk_size_idx = gguf_find_key(ctx_gguf, LLM_KV_IMATRIX_CHUNK_SIZE); - if (dataset_idx < 0 || chunk_count_idx < 0 || chunk_size_idx < 0) { - fprintf(stderr, "%s: missing imatrix metadata in file %s\n", __func__, imatrix_file.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); - exit(1); - } - - const uint32_t chunk_size = gguf_get_val_u32(ctx_gguf, chunk_size_idx); - - const std::string sums_suffix{".sums"}; - const std::string 
counts_suffix{".counts"}; - - // Using an ordered map to get a deterministic iteration order. - std::map> sums_counts_for; - - for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string name = cur->name; - - if (name.empty()) { continue; } - - if (str_remove_suffix(name, sums_suffix)) { - // sums - sums_counts_for[name].first = cur; - } else if (str_remove_suffix(name, counts_suffix)) { - // counts - sums_counts_for[name].second = cur; - } else { - fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); + for (int i = 0; i < n_entries; ++i) { + int len; in.read((char *)&len, sizeof(len)); + std::vector name_as_vec(len+1); + in.read((char *)name_as_vec.data(), len); + if (in.fail()) { + printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str()); exit(1); } - } - - for (const auto & sc : sums_counts_for) { - const std::string & name = sc.first; - const struct ggml_tensor * sums = sc.second.first; - const struct ggml_tensor * counts = sc.second.second; - - if (!sums || !counts) { - fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); - gguf_free(ctx_gguf); - ggml_free(ctx); - exit(1); - } - - const int64_t ne0 = sums->ne[0]; - const int64_t ne1 = sums->ne[1]; - + name_as_vec[len] = 0; + std::string name{name_as_vec.data()}; auto & e = imatrix_data[name]; - e.resize(ggml_nelements(sums)); - float max_count = 0.0f; - for (int64_t j = 0; j < ne1; ++j) { - const float count = ((const float *) counts->data)[j]; - for (int64_t i = 0; i < ne0; ++i) { - e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; - } - if (count > max_count) { - max_count = count; - } + int ncall; + in.read((char *)&ncall, sizeof(ncall)); + int nval; + in.read((char *)&nval, sizeof(nval)); + if (in.fail() || nval < 1) { + printf("%s: failed reading number of values for entry %d\n", __func__, i); + imatrix_data = {}; + exit(1); } + e.resize(nval); + in.read((char *)e.data(), nval*sizeof(float)); + if (in.fail()) { + printf("%s: failed reading data for entry %d\n", __func__, i); + imatrix_data = {}; + exit(1); + } + if (ncall > 0) { + for (auto& v : e) v /= ncall; + } + if (getenv("LLAMA_TRACE")) { - printf("%s: loaded data (size = %6d, n_tokens = %6d, n_chunks = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), int(max_count / chunk_size), name.c_str()); + printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str()); } } - int m_last_chunk = gguf_get_val_u32(ctx_gguf, chunk_count_idx); - imatrix_dataset = gguf_get_val_str(ctx_gguf, dataset_idx); - - printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); - printf("%s: loaded %d importance matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_chunk); - - gguf_free(ctx_gguf); - ggml_free(ctx); - - return m_last_chunk; + // latest imatrix version contains the dataset filename at the end of the file + int m_last_call = 0; + if (in.peek() != EOF) { + in.read((char *)&m_last_call, sizeof(m_last_call)); + int dataset_len; + in.read((char *)&dataset_len, sizeof(dataset_len)); + std::vector dataset_as_vec(dataset_len); + in.read(dataset_as_vec.data(), dataset_len); + imatrix_dataset.assign(dataset_as_vec.begin(), dataset_as_vec.end()); + printf("%s: imatrix dataset='%s'\n", __func__, imatrix_dataset.c_str()); + } + printf("%s: loaded %d importance 
matrix entries from %s computed on %d chunks\n", __func__, int(imatrix_data.size()), imatrix_file.c_str(), m_last_call); + return m_last_call; } static int prepare_imatrix(const std::string & imatrix_file, diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h index 53cdba072..ed5ea5f79 100644 --- a/ggml/include/ggml-vulkan.h +++ b/ggml/include/ggml-vulkan.h @@ -10,8 +10,6 @@ extern "C" { #define GGML_VK_NAME "Vulkan" #define GGML_VK_MAX_DEVICES 16 -GGML_BACKEND_API void ggml_vk_instance_init(void); - // backend API GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d32ba4efb..bffe95086 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -167,6 +167,7 @@ struct vk_device_struct { uint32_t subgroup_size; uint32_t shader_core_count; bool uma; + bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_size_control; @@ -1294,7 +1295,9 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk: static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { vk_buffer buf; try { - if (device->uma) { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); + } else if (device->uma) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); } else { @@ -2199,6 +2202,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -2787,14 +2793,12 @@ static void ggml_vk_print_gpu_info(size_t idx) { static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); -void ggml_vk_instance_init() { +static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; } VK_LOG_DEBUG("ggml_vk_instance_init()"); - vk_instance_initialized = true; - uint32_t api_version = vk::enumerateInstanceVersion(); if (api_version < VK_API_VERSION_1_2) { @@ -2845,6 +2849,7 @@ void ggml_vk_instance_init() { GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); } vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); @@ -2869,7 +2874,7 @@ void ggml_vk_instance_init() { // Make sure at least one device exists if (devices.empty()) { std::cerr << "ggml_vulkan: Error: No devices found." 
<< std::endl;
-        GGML_ABORT("fatal error");
+        return;
     }
 
     // Default to using all dedicated GPUs
@@ -8344,8 +8349,13 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
         /* .iface   = */ ggml_backend_vk_reg_i,
         /* .context = */ nullptr,
     };
-
-    return &reg;
+    try {
+        ggml_vk_instance_init();
+        return &reg;
+    } catch (const vk::SystemError& e) {
+        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
+        return nullptr;
+    }
 }
 
 // Extension availability
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 0b1f21fc7..ecac5b4bb 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -211,12 +211,6 @@ class Keys:
         TYPE       = "adapter.type"
         LORA_ALPHA = "adapter.lora.alpha"
 
-    class IMatrix:
-        CHUNK_COUNT = "imatrix.chunk_count"
-        CHUNK_SIZE  = "imatrix.chunk_size"
-        DATASET     = "imatrix.dataset"
-
-
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -225,7 +219,6 @@ class Keys:
 class GGUFType:
     MODEL   = "model"
     ADAPTER = "adapter"
-    IMATRIX = "imatrix"
 
 
 class MODEL_ARCH(IntEnum):
diff --git a/requirements.txt b/requirements.txt
index 98c53db81..9e190ae27 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,5 @@
 
 -r ./requirements/requirements-convert_hf_to_gguf.txt
 -r ./requirements/requirements-convert_hf_to_gguf_update.txt
--r ./requirements/requirements-convert_legacy_imatrix_to_gguf.txt
 -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
 -r ./requirements/requirements-convert_lora_to_gguf.txt
diff --git a/requirements/requirements-convert_legacy_imatrix_to_gguf.txt b/requirements/requirements-convert_legacy_imatrix_to_gguf.txt
deleted file mode 100644
index afe2747d4..000000000
--- a/requirements/requirements-convert_legacy_imatrix_to_gguf.txt
+++ /dev/null
@@ -1 +0,0 @@
--r ./requirements-convert_legacy_llama.txt
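For reference, the legacy imatrix.dat layout reinstated above can be read back with a few lines of Python. The sketch below simply mirrors what save_imatrix() in examples/imatrix/imatrix.cpp writes (n_entries, then per entry name_len, name, ncall, nval and nval floats, optionally followed by the call count and dataset filename); it is not part of the patch, the helper name is illustrative, and it assumes 4-byte little-endian ints and floats as produced by the C++ code on common platforms.

#!/usr/bin/env python3
# Illustrative reader for the legacy imatrix.dat format (not part of this patch).
import struct
import sys


def read_legacy_imatrix(path: str = "imatrix.dat") -> None:
    with open(path, "rb") as f:
        def read_i32() -> int:
            return struct.unpack("<i", f.read(4))[0]

        n_entries = read_i32()
        for _ in range(n_entries):
            name = f.read(read_i32()).decode("utf-8")   # name_len, then name bytes
            ncall = read_i32()
            nval = read_i32()
            values = struct.unpack(f"<{nval}f", f.read(4 * nval))  # consume the stored sums
            print(f"{name}: ncall={ncall}, nval={nval}")

        trailer = f.read(4)
        if trailer:  # newer files append the call count and the dataset filename
            last_call = struct.unpack("<i", trailer)[0]
            dataset = f.read(read_i32()).decode("utf-8")
            print(f"computed over {last_call} chunks on '{dataset}'")


if __name__ == "__main__":
    read_legacy_imatrix(sys.argv[1] if len(sys.argv) > 1 else "imatrix.dat")

Pointing the sketch at an imatrix.dat produced by the imatrix example prints each entry's name, call count, and value count, plus the trailing dataset information when present.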