Merge branch 'ggerganov:master' into master

This commit is contained in:
bmwl 2024-02-12 20:39:41 -08:00 committed by GitHub
commit 87f8d9e5e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
54 changed files with 2975 additions and 1795 deletions

View file

@ -1,2 +1,3 @@
[flake8] [flake8]
max-line-length = 125 max-line-length = 125
ignore = W503

View file

@ -16,5 +16,5 @@ jobs:
- name: flake8 Lint - name: flake8 Lint
uses: py-actions/flake8@v2 uses: py-actions/flake8@v2
with: with:
ignore: "E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704" ignore: "E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704,W503"
exclude: "examples/*,examples/*/**,*/**/__init__.py" exclude: "examples/*,examples/*/**,*/**/__init__.py"

View file

@ -13,17 +13,31 @@ let package = Package(
products: [ products: [
.library(name: "llama", targets: ["llama"]), .library(name: "llama", targets: ["llama"]),
], ],
dependencies: [
.package(url: "https://github.com/ggerganov/ggml.git", .branch("release"))
],
targets: [ targets: [
.target( .target(
name: "llama", name: "llama",
dependencies: ["ggml"],
path: ".", path: ".",
exclude: ["ggml-metal.metal"], exclude: [
"cmake",
"examples",
"scripts",
"models",
"tests",
"CMakeLists.txt",
"ggml-cuda.cu",
"ggml-cuda.h",
"Makefile"
],
sources: [ sources: [
"ggml.c",
"llama.cpp", "llama.cpp",
"ggml-alloc.c",
"ggml-backend.c",
"ggml-quants.c",
"ggml-metal.m",
],
resources: [
.process("ggml-metal.metal")
], ],
publicHeadersPath: "spm-headers", publicHeadersPath: "spm-headers",
cSettings: [ cSettings: [

View file

@ -124,6 +124,7 @@ Typically finetunes of the base models below are supported as well.
- Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) - JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
- JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm)
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)

View file

@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.samplers_sequence = parse_samplers_input(argv[i]); const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = sampler_types_from_names(sampler_names);
} else if (arg == "--sampling-seq") { } else if (arg == "--sampling-seq") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.samplers_sequence = argv[i]; sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
} else if (arg == "--top-p") { } else if (arg == "--top-p") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -915,6 +916,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sparams; const llama_sampling_params & sparams = params.sparams;
std::string sampler_type_chars;
std::string sampler_type_names;
for (const auto sampler_type : sparams.samplers_sequence) {
sampler_type_chars += static_cast<char>(sampler_type);
sampler_type_names += sampler_type_to_name_string(sampler_type) + ";";
}
sampler_type_names.pop_back();
printf("\n"); printf("\n");
printf("usage: %s [options]\n", argv[0]); printf("usage: %s [options]\n", argv[0]);
printf("\n"); printf("\n");
@ -956,8 +965,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@ -1107,45 +1116,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
} }
// //
// String parsing // String utils
// //
std::string parse_samplers_input(std::string input) { std::vector<std::string> string_split(std::string input, char separator) {
std::string output = ""; std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
while (separator_pos != std::string::npos) {
std::string part = input.substr(0, separator_pos);
parts.emplace_back(part);
input = input.substr(separator_pos + 1);
separator_pos = input.find(separator);
}
parts.emplace_back(input);
return parts;
}
std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) {
// since samplers names are written multiple ways // since samplers names are written multiple ways
// make it ready for both system names and input names // make it ready for both system names and input names
std::unordered_map<std::string, char> samplers_symbols { std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
{"top_k", 'k'}, {"top_k", llama_sampler_type::TOP_K},
{"top-k", 'k'}, {"top-k", llama_sampler_type::TOP_K},
{"top_p", 'p'}, {"top_p", llama_sampler_type::TOP_P},
{"top-p", 'p'}, {"top-p", llama_sampler_type::TOP_P},
{"nucleus", 'p'}, {"nucleus", llama_sampler_type::TOP_P},
{"typical_p", 'y'}, {"typical_p", llama_sampler_type::TYPICAL_P},
{"typical-p", 'y'}, {"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", 'y'}, {"typical", llama_sampler_type::TYPICAL_P},
{"min_p", 'm'}, {"min_p", llama_sampler_type::MIN_P},
{"min-p", 'm'}, {"min-p", llama_sampler_type::MIN_P},
{"tfs_z", 'f'}, {"tfs_z", llama_sampler_type::TFS_Z},
{"tfs-z", 'f'}, {"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", 'f'}, {"tfs", llama_sampler_type::TFS_Z},
{"temp", 't'}, {"temp", llama_sampler_type::TEMP},
{"temperature",'t'} {"temperature", llama_sampler_type::TEMP}
}; };
// expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
size_t separator = input.find(';');
while (separator != input.npos) {
std::string name = input.substr(0,separator);
input = input.substr(separator+1);
separator = input.find(';');
if (samplers_symbols.find(name) != samplers_symbols.end()) { std::vector<llama_sampler_type> sampler_types;
output += samplers_symbols[name]; sampler_types.reserve(names.size());
for (const auto& name : names) {
const auto sampler_item = sampler_name_map.find(name);
if (sampler_item != sampler_name_map.end()) {
sampler_types.push_back(sampler_item->second);
} }
} }
if (samplers_symbols.find(input) != samplers_symbols.end()) { return sampler_types;
output += samplers_symbols[input]; }
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string) {
std::unordered_map<char, llama_sampler_type> sampler_name_map {
{'k', llama_sampler_type::TOP_K},
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'f', llama_sampler_type::TFS_Z},
{'t', llama_sampler_type::TEMP}
};
std::vector<llama_sampler_type> sampler_types;
sampler_types.reserve(names_string.size());
for (const auto & c : names_string) {
const auto sampler_item = sampler_name_map.find(c);
if (sampler_item != sampler_name_map.end()) {
sampler_types.push_back(sampler_item->second);
}
}
return sampler_types;
}
std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
switch (sampler_type) {
case llama_sampler_type::TOP_K: return "top_k";
case llama_sampler_type::TFS_Z: return "tfs_z";
case llama_sampler_type::TYPICAL_P: return "typical_p";
case llama_sampler_type::TOP_P: return "top_p";
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::TEMP: return "temp";
default : return "";
} }
return output;
} }
// //
@ -1560,6 +1609,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false");
#ifdef NDEBUG #ifdef NDEBUG
fprintf(stream, "debug: false\n"); fprintf(stream, "debug: false\n");

View file

@ -162,10 +162,13 @@ std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input); void process_escapes(std::string& input);
// //
// String parsing // String utils
// //
std::string parse_samplers_input(std::string input); std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names);
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
std::vector<std::string> string_split(std::string input, char separator);
std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
// //
// Model utils // Model utils

View file

@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties "; std::string result = "CFG -> Penalties ";
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (auto s : params.samplers_sequence) { for (auto sampler_type : params.samplers_sequence) {
switch (s) { const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
case 'k': result += "-> top_k "; break; if (!sampler_type_name.empty()) {
case 'f': result += "-> tfs_z "; break; result += "-> " + sampler_type_name + " ";
case 'y': result += "-> typical_p "; break;
case 'p': result += "-> top_p "; break;
case 'm': result += "-> min_p "; break;
case 't': result += "-> temp "; break;
default : break;
} }
} }
} else { } else {
@ -127,8 +122,6 @@ static void sampler_queue(
const llama_sampling_params & params, const llama_sampling_params & params,
llama_token_data_array & cur_p, llama_token_data_array & cur_p,
size_t & min_keep) { size_t & min_keep) {
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const float temp = params.temp; const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range; const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent; const float dynatemp_exponent = params.dynatemp_exponent;
@ -137,16 +130,16 @@ static void sampler_queue(
const float min_p = params.min_p; const float min_p = params.min_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const std::string & samplers_sequence = params.samplers_sequence; const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
for (auto s : samplers_sequence) { for (auto sampler_type : samplers_sequence) {
switch (s){ switch (sampler_type) {
case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case 't': case llama_sampler_type::TEMP:
if (dynatemp_range > 0) { if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range);

View file

@ -8,6 +8,16 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
// sampler types
enum class llama_sampler_type : char {
TOP_K = 'k',
TOP_P = 'p',
MIN_P = 'm',
TFS_Z = 'f',
TYPICAL_P = 'y',
TEMP = 't'
};
// sampling parameters // sampling parameters
typedef struct llama_sampling_params { typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
@ -28,7 +38,15 @@ typedef struct llama_sampling_params {
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token bool penalize_nl = true; // consider newlines as a repeatable token
std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
llama_sampler_type::TFS_Z,
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
llama_sampler_type::TEMP
};
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling

View file

@ -209,6 +209,8 @@ class Model:
return InternLM2Model return InternLM2Model
if model_architecture == "MiniCPMForCausalLM": if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel return MiniCPMModel
if model_architecture == "BertModel":
return BertModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool:
@ -264,6 +266,8 @@ class Model:
return gguf.MODEL_ARCH.INTERNLM2 return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM": if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM return gguf.MODEL_ARCH.MINICPM
if arch == "BertModel":
return gguf.MODEL_ARCH.BERT
raise NotImplementedError(f'Architecture "{arch}" not supported!') raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.")
self.post_write_tensors(tensor_map, name, data_torch) self.post_write_tensors(tensor_map, name, data_torch)
class BertModel(Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_count = self.hparams["num_hidden_layers"]
def set_gguf_parameters(self):
# TODO(cebtenzzre): merge with parent class
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_file_type(self.ftype)
def set_vocab(self):
path = self.dir_model
added_tokens_path = self.dir_model if self.dir_model.exists() else None
# use huggingface vocab to get all tokens
vocab = HfVocab(path, added_tokens_path)
tokens, scores, toktypes = zip(*vocab.all_tokens())
assert len(tokens) == vocab.vocab_size
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
n_token_types = len(set(toktypes))
self.gguf_writer.add_token_type_count(n_token_types)
# convert to phantom space vocab
def phantom(tok, typ):
if tok.startswith(b"[") and tok.endswith(b"]"):
return tok
if tok.startswith(b"##"):
return tok[2:]
return b"\xe2\x96\x81" + tok
tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
# set up bos and eos tokens (cls and sep)
self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
# add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
# handle special tokens
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def write_tensors(self):
tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
tensors = dict(self.get_tensors())
for name, data_torch in tensors.items():
# we are only using BERT for embeddings so we don't need the pooling layer
if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
continue # we don't need these
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
data = data_torch.squeeze().numpy()
n_dims = len(data.shape)
new_dtype: type[np.floating[Any]]
if (
self.ftype == 1 and name.endswith(".weight") and n_dims == 2
and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32
):
# if f16 desired, convert any float32 2-dim weight tensors to float16
new_dtype = np.float16
else:
# if f32 desired, convert any float16 to float32
new_dtype = np.float32
print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
if data.dtype != new_dtype:
data = data.astype(new_dtype)
self.gguf_writer.add_tensor(new_name, data)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######

View file

@ -88,7 +88,8 @@ def main():
gguf_writer.add_embedding_length(hidden_size) gguf_writer.add_embedding_length(hidden_size)
gguf_writer.add_block_count(block_count) gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size) gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
gguf_writer.add_rope_dimension_count(hidden_size // head_count) # ref: https://github.com/ggerganov/llama.cpp/pull/4889/commits/eea19039fc52ea2dbd1aab45b59ab4e3e29a3443
gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
gguf_writer.add_head_count(head_count) gguf_writer.add_head_count(head_count)
gguf_writer.add_head_count_kv(head_count_kv) gguf_writer.add_head_count_kv(head_count_kv)
gguf_writer.add_rope_freq_base(hparams.rotary_emb_base) gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)

View file

@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
} }
const int n_embd = llama_n_embd(model); const int n_embd = llama_n_embd(model);
const auto * embeddings = llama_get_embeddings(ctx); auto * embeddings = llama_get_embeddings(ctx);
// l2-normalize embeddings
float norm = 0;
for (int i = 0; i < n_embd; i++) {
norm += embeddings[i] * embeddings[i];
}
norm = sqrt(norm);
for (int i = 0; i < n_embd; i++) {
embeddings[i] /= norm;
}
for (int i = 0; i < n_embd; i++) { for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]); printf("%f ", embeddings[i]);

View file

@ -337,24 +337,14 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
params.mem_buffer = NULL; params.mem_buffer = NULL;
params.no_alloc = true; params.no_alloc = true;
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
struct ggml_allocr * alloc = NULL; struct ggml_gallocr * alloc = NULL;
struct ggml_cgraph * gf = NULL; struct ggml_cgraph * gf = NULL;
ctx = ggml_init(params); ctx = ggml_init(params);
alloc = ggml_allocr_new_measure(tensor_alignment); alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling); gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
ggml_allocr_free(alloc);
ggml_free(ctx);
static std::vector<uint8_t> data_compute; ggml_gallocr_alloc_graph(alloc, gf);
data_compute.resize(alloc_size + tensor_alignment);
ctx = ggml_init(params);
alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
ggml_allocr_alloc_graph(alloc, gf);
ggml_allocr_free(alloc);
struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads); struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
static std::vector<uint8_t> data_work; static std::vector<uint8_t> data_work;
@ -363,6 +353,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
ggml_graph_compute(gf, &cplan); ggml_graph_compute(gf, &cplan);
ggml_gallocr_free(alloc);
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }

View file

@ -1,5 +1,6 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h"
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "train.h" #include "train.h"
@ -13,8 +14,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
static const size_t tensor_alignment = 32;
struct my_llama_hparams { struct my_llama_hparams {
uint32_t n_vocab = 32000; uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; uint32_t n_ctx = 512;
@ -128,7 +127,7 @@ struct my_llama_lora_layer {
struct my_llama_lora { struct my_llama_lora {
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
std::vector<uint8_t> data; ggml_backend_buffer_t data;
my_llama_lora_hparams hparams; my_llama_lora_hparams hparams;
@ -372,63 +371,6 @@ static void set_param_lora(struct my_llama_lora * lora) {
} }
} }
static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
ggml_allocr_alloc(alloc, lora->norm_a);
ggml_allocr_alloc(alloc, lora->norm_b);
ggml_allocr_alloc(alloc, lora->output_a);
ggml_allocr_alloc(alloc, lora->output_b);
for (uint32_t i = 0; i < lora->layers.size(); ++i) {
auto & layer = lora->layers[i];
ggml_allocr_alloc(alloc, layer.attention_norm_a);
ggml_allocr_alloc(alloc, layer.attention_norm_b);
ggml_allocr_alloc(alloc, layer.wq_a);
ggml_allocr_alloc(alloc, layer.wq_b);
ggml_allocr_alloc(alloc, layer.wk_a);
ggml_allocr_alloc(alloc, layer.wk_b);
ggml_allocr_alloc(alloc, layer.wv_a);
ggml_allocr_alloc(alloc, layer.wv_b);
ggml_allocr_alloc(alloc, layer.wo_a);
ggml_allocr_alloc(alloc, layer.wo_b);
ggml_allocr_alloc(alloc, layer.ffn_norm_a);
ggml_allocr_alloc(alloc, layer.ffn_norm_b);
ggml_allocr_alloc(alloc, layer.w1_a);
ggml_allocr_alloc(alloc, layer.w1_b);
ggml_allocr_alloc(alloc, layer.w2_a);
ggml_allocr_alloc(alloc, layer.w2_b);
ggml_allocr_alloc(alloc, layer.w3_a);
ggml_allocr_alloc(alloc, layer.w3_b);
}
ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
ggml_allocr_alloc(alloc, lora->norm_a->grad);
ggml_allocr_alloc(alloc, lora->norm_b->grad);
ggml_allocr_alloc(alloc, lora->output_a->grad);
ggml_allocr_alloc(alloc, lora->output_b->grad);
for (uint32_t i = 0; i < lora->layers.size(); ++i) {
auto & layer = lora->layers[i];
ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
ggml_allocr_alloc(alloc, layer.wq_a->grad);
ggml_allocr_alloc(alloc, layer.wq_b->grad);
ggml_allocr_alloc(alloc, layer.wk_a->grad);
ggml_allocr_alloc(alloc, layer.wk_b->grad);
ggml_allocr_alloc(alloc, layer.wv_a->grad);
ggml_allocr_alloc(alloc, layer.wv_b->grad);
ggml_allocr_alloc(alloc, layer.wo_a->grad);
ggml_allocr_alloc(alloc, layer.wo_b->grad);
ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
ggml_allocr_alloc(alloc, layer.w1_a->grad);
ggml_allocr_alloc(alloc, layer.w1_b->grad);
ggml_allocr_alloc(alloc, layer.w2_a->grad);
ggml_allocr_alloc(alloc, layer.w2_b->grad);
ggml_allocr_alloc(alloc, layer.w3_a->grad);
ggml_allocr_alloc(alloc, layer.w3_b->grad);
}
}
static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) { static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
const auto & lparams = lora->hparams; const auto & lparams = lora->hparams;
@ -522,18 +464,8 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
set_param_lora(lora); set_param_lora(lora);
// measure data size // allocate data for lora tensors
size_t size = 0; lora->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
}
// allocate data
struct ggml_allocr * alloc = NULL;
lora->data.resize(size + tensor_alignment);
alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
alloc_lora(alloc, lora);
ggml_allocr_free(alloc);
} }
static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) { static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
@ -579,7 +511,7 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
static struct ggml_tensor * llama_build_lora_finetune_graphs( static struct ggml_tensor * llama_build_lora_finetune_graphs(
struct my_llama_model * model, struct my_llama_model * model,
struct my_llama_lora * lora, struct my_llama_lora * lora,
struct ggml_allocr * alloc, ggml_gallocr_t alloc,
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
struct ggml_cgraph * gb, struct ggml_cgraph * gb,
@ -590,7 +522,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_tokens, const int n_tokens,
const int n_batch, const int n_batch,
const bool enable_flash_attn, const bool enable_flash_attn,
const bool enable_checkpointing) { const bool enable_checkpointing,
const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, }); ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0; const int n_past = 0;
@ -622,13 +555,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
ggml_allocr_alloc(alloc, KQ_pos); ggml_set_input(KQ_pos);
if (!ggml_allocr_is_measure(alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// rope has so much parameters that we make a custom function for it // rope has so much parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale] auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@ -780,7 +707,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// input gradient // input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f)); ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_allocr_alloc(alloc, t36->grad); ggml_set_input(t36->grad);
// KQ_pos // KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f)); ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
@ -805,11 +732,23 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
// note: they will be freed in reverse order // note: they will be freed in reverse order
for (unsigned int i = 0; i < checkpoints.size(); ++i) { for (unsigned int i = 0; i < checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) { if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
ggml_allocr_alloc(alloc, checkpoints[i]); ggml_set_input(checkpoints[i]);
} }
} }
ggml_allocr_alloc_graph(alloc, gb); if (measure_only) {
ggml_gallocr_reserve(alloc, gb);
} else {
ggml_gallocr_alloc_graph(alloc, gb);
// set KQ_pos
{
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
}
// remove the additional nodes and leafs // remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) { for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@ -1663,7 +1602,7 @@ int main(int argc, char ** argv) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples); printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens); printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs); printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f)); printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)), (float) (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)) / (1024.0f*1024.0f));
if (params.only_write_lora) { if (params.only_write_lora) {
save_train_files_data save_data; save_train_files_data save_data;
@ -1690,10 +1629,6 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch; int n_batch = params.common.n_batch;
std::vector<uint8_t> mem_input_data;
std::vector<uint8_t> mem_compute_data;
// context for input tensors without their data // context for input tensors without their data
struct ggml_init_params ctx_input_params = { struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size ggml_tensor_overhead() * 2, // mem_size
@ -1706,17 +1641,11 @@ int main(int argc, char ** argv) {
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch); struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// measure required memory for input tensors
size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
tensor_alignment;
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// allocate input tensors // allocate input tensors
mem_input_data.resize(max_input_size); // measure required memory for input tensors
ggml_allocr_t alloc_inps = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
ggml_allocr_alloc(alloc_inps, tokens_input); size_t max_input_size = ggml_backend_buffer_get_size(input_data);
ggml_allocr_alloc(alloc_inps, target_probs); printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// context for compute tensors without their data // context for compute tensors without their data
const size_t estimated_compute_size_wo_data = ( const size_t estimated_compute_size_wo_data = (
@ -1743,7 +1672,7 @@ int main(int argc, char ** argv) {
// find best evaluation order // find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) { for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params); ctx_compute = ggml_init(ctx_compute_params);
ggml_allocr_t alloc = ggml_allocr_new_measure(tensor_alignment); ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order; gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@ -1756,14 +1685,15 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.common.use_flash, params.common.use_flash,
params.common.use_checkpointing params.common.use_checkpointing,
true
); );
size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment; size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) { if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size; best_compute_size = max_compute_size;
best_order = gf->order; best_order = gf->order;
} }
ggml_allocr_free(alloc); ggml_gallocr_free(alloc);
ggml_free(ctx_compute); ggml_free(ctx_compute);
} }
size_t max_compute_size = best_compute_size; size_t max_compute_size = best_compute_size;
@ -1774,9 +1704,8 @@ int main(int argc, char ** argv) {
"invalid"); "invalid");
// allocate compute tensors // allocate compute tensors
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params); ctx_compute = ggml_init(ctx_compute_params);
ggml_allocr_t alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment); ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order; gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@ -1789,11 +1718,9 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.common.use_flash, params.common.use_flash,
params.common.use_checkpointing params.common.use_checkpointing,
false
); );
ggml_allocr_free(alloc);
ggml_allocr_free(alloc_inps);
// tokenize data // tokenize data
std::vector<llama_token> train_tokens; std::vector<llama_token> train_tokens;
@ -1908,6 +1835,8 @@ int main(int argc, char ** argv) {
ggml_free(ctx_work); ggml_free(ctx_work);
ggml_free(ctx_compute); ggml_free(ctx_compute);
ggml_free(ctx_input); ggml_free(ctx_input);
ggml_gallocr_free(alloc);
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
printf("%s: total training time: ", __func__); printf("%s: total training time: ", __func__);

View file

@ -29,19 +29,25 @@ git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
git clone https://huggingface.co/openai/clip-vit-large-patch14-336 git clone https://huggingface.co/openai/clip-vit-large-patch14-336
``` ```
2. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: 2. Install the required Python packages:
```sh
pip install -r examples/llava/requirements.txt
```
3. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
```sh ```sh
python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b
``` ```
3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF: 4. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF:
```sh ```sh
python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
``` ```
4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF: 5. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
```sh ```sh
python ./convert.py ../llava-v1.5-7b python ./convert.py ../llava-v1.5-7b

View file

@ -367,7 +367,7 @@ struct clip_ctx {
ggml_backend_buffer_t params_buffer = NULL; ggml_backend_buffer_t params_buffer = NULL;
ggml_backend_buffer_t compute_buffer = NULL; ggml_backend_buffer_t compute_buffer = NULL;
ggml_backend_t backend = NULL; ggml_backend_t backend = NULL;
ggml_allocr * compute_alloc = NULL; ggml_gallocr_t compute_alloc = NULL;
}; };
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) { static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@ -405,31 +405,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size); struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
ggml_allocr_alloc(ctx->compute_alloc, inp_raw); ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
float * data = (float *)malloc(ggml_nbytes(inp_raw));
for (size_t i = 0; i < imgs->size; i++) {
const int nx = imgs->data[i].nx;
const int ny = imgs->data[i].ny;
GGML_ASSERT(nx == image_size && ny == image_size);
const int n = nx * ny;
for (int b = 0; b < batch_size; b++) {
for (int k = 0; k < 3; k++) {
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
}
}
}
}
}
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
free(data);
}
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@ -438,13 +415,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// concat class_embeddings and patch_embeddings // concat class_embeddings and patch_embeddings
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
ggml_allocr_alloc(ctx->compute_alloc, embeddings); ggml_set_name(embeddings, "embeddings");
if (!ggml_allocr_is_measure(ctx->compute_alloc)) { ggml_set_input(embeddings);
void* zero_mem = malloc(ggml_nbytes(embeddings));
memset(zero_mem, 0, ggml_nbytes(embeddings));
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
free(zero_mem);
}
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@ -453,15 +425,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
ggml_allocr_alloc(ctx->compute_alloc, positions); ggml_set_name(positions, "positions");
if (!ggml_allocr_is_measure(ctx->compute_alloc)) { ggml_set_input(positions);
int* positions_data = (int*)malloc(ggml_nbytes(positions));
for (int i = 0; i < num_positions; i++) {
positions_data[i] = i;
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
embeddings = embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
@ -560,15 +525,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_allocr_alloc(ctx->compute_alloc, patches); ggml_set_name(patches, "patches");
if (!ggml_allocr_is_measure(ctx->compute_alloc)) { ggml_set_input(patches);
int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) {
patches_data[i] = i + 1;
}
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data);
}
// shape [1, 576, 1024] // shape [1, 576, 1024]
// ne is whcn, ne = [1024, 576, 1, 1] // ne is whcn, ne = [1024, 576, 1, 1]
@ -809,7 +767,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
} }
// data // data
size_t buffer_size = 0; size_t model_size = 0;
{ {
for (int i = 0; i < n_tensors; ++i) { for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i); const char * name = gguf_get_tensor_name(ctx, i);
@ -817,7 +775,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
enum ggml_type type = gguf_get_tensor_type(ctx, i); enum ggml_type type = gguf_get_tensor_type(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(meta, name); struct ggml_tensor * cur = ggml_get_tensor(meta, name);
size_t tensor_size = ggml_nbytes(cur); size_t tensor_size = ggml_nbytes(cur);
buffer_size += tensor_size; model_size += tensor_size;
if (verbosity >= 3) { if (verbosity >= 3) {
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
__func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
@ -825,8 +783,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
} }
} }
buffer_size += n_tensors * 128 /* CLIP PADDING */;
clip_ctx * new_clip = new clip_ctx; clip_ctx * new_clip = new clip_ctx;
// update projector type // update projector type
@ -886,12 +842,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0); printf("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
} }
} }
printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors); printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
// load tensors // load tensors
{ {
@ -925,12 +881,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
} }
// alloc memory and offload data // alloc memory and offload data
new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size); new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
for (int i = 0; i < n_tensors; ++i) { for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i); const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name); struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
ggml_allocr_alloc(alloc, cur);
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i); const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
fin.seekg(offset, std::ios::beg); fin.seekg(offset, std::ios::beg);
if (!fin) { if (!fin) {
@ -949,7 +903,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
} }
} }
ggml_allocr_free(alloc);
fin.close(); fin.close();
} }
@ -1077,15 +1030,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// measure mem requirement and allocate // measure mem requirement and allocate
{ {
new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend); new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
clip_image_f32_batch batch; clip_image_f32_batch batch;
batch.size = 1; batch.size = 1;
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch); ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf); ggml_gallocr_reserve(new_clip->compute_alloc, gf);
ggml_allocr_free(new_clip->compute_alloc); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0); printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
} }
@ -1267,12 +1217,72 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(batch_size == 1); // TODO: support multiple images GGML_ASSERT(batch_size == 1); // TODO: support multiple images
} }
// reset alloc buffer to clean the memory from previous invocations
ggml_allocr_reset(ctx->compute_alloc);
// build the inference graph // build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_allocr_alloc_graph(ctx->compute_alloc, gf); ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
// set inputs
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
const int num_positions = num_patches + 1;
{
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
float * data = (float *)malloc(ggml_nbytes(inp_raw));
for (size_t i = 0; i < imgs->size; i++) {
const int nx = imgs->data[i].nx;
const int ny = imgs->data[i].ny;
GGML_ASSERT(nx == image_size && ny == image_size);
const int n = nx * ny;
for (int b = 0; b < batch_size; b++) {
for (int k = 0; k < 3; k++) {
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
}
}
}
}
}
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
free(data);
}
{
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
void* zero_mem = malloc(ggml_nbytes(embeddings));
memset(zero_mem, 0, ggml_nbytes(embeddings));
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
free(zero_mem);
}
{
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
for (int i = 0; i < num_positions; i++) {
positions_data[i] = i;
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
{
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
int* patches_data = (int*)malloc(ggml_nbytes(patches));
for (int i = 0; i < num_patches; i++) {
patches_data[i] = i + 1;
}
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
free(patches_data);
}
if (ggml_backend_is_cpu(ctx->backend)) { if (ggml_backend_is_cpu(ctx->backend)) {
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);

View file

@ -71,7 +71,7 @@ def bytes_to_unicode():
return dict(zip(bs, cs)) return dict(zip(bs, cs))
ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py") ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False, ap.add_argument("--text-only", action="store_true", required=False,

View file

@ -42,5 +42,5 @@ if len(clip_tensors) > 0:
torch.save(checkpoint, path) torch.save(checkpoint, path)
print("Done!") print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

View file

@ -0,0 +1,3 @@
-r ../../requirements/requirements-convert.txt
pillow~=10.2.0
torch~=2.1.1

View file

@ -1,7 +1,9 @@
#include "common.h" #include "common.h"
#include "ggml.h"
#include "llama.h" #include "llama.h"
#include <cmath> #include <cmath>
#include <cstdint>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <vector> #include <vector>
@ -73,6 +75,8 @@ int main(int argc, char ** argv){
int n_drafted = 0; int n_drafted = 0;
int n_accept = 0; int n_accept = 0;
int64_t t_draft_us = 0;
int n_past = inp.size(); int n_past = inp.size();
bool has_eos = false; bool has_eos = false;
@ -160,7 +164,7 @@ int main(int argc, char ** argv){
// generate n_pred tokens through prompt lookup // generate n_pred tokens through prompt lookup
auto prompt_lookup = [&]() -> void { auto prompt_lookup = [&]() -> void {
int inp_size = inp.size(); const int inp_size = inp.size();
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){ for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size]; const llama_token * ngram = &inp[inp_size - ngram_size];
@ -191,8 +195,12 @@ int main(int argc, char ** argv){
return; return;
}; };
const int64_t t_start_draft_us = ggml_time_us();
prompt_lookup(); prompt_lookup();
t_draft_us += ggml_time_us() - t_start_draft_us;
llama_decode(ctx, batch_tgt); llama_decode(ctx, batch_tgt);
++n_past; ++n_past;
@ -210,6 +218,8 @@ int main(int argc, char ** argv){
LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted); LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

View file

@ -98,7 +98,7 @@ static void write_logfile(
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) { static void sigint_handler(int signo) {
if (signo == SIGINT) { if (signo == SIGINT) {
if (!is_interacting) { if (!is_interacting && g_params->interactive) {
is_interacting = true; is_interacting = true;
} else { } else {
console::cleanup(); console::cleanup();
@ -392,7 +392,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n"); LOG_TEE("\n");
} }
if (params.interactive) { // ctrl+C handling
{
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action; struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler; sigint_action.sa_handler = sigint_handler;
@ -405,7 +406,9 @@ int main(int argc, char ** argv) {
}; };
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif #endif
}
if (params.interactive) {
LOG_TEE("%s: interactive mode on.\n", __func__); LOG_TEE("%s: interactive mode on.\n", __func__);
if (!params.antiprompt.empty()) { if (!params.antiprompt.empty()) {

View file

@ -185,7 +185,7 @@ node index.js
`ignore_eos`: Ignore end of stream token and continue generating (default: false). `ignore_eos`: Ignore end of stream token and continue generating (default: false).
`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced (default: []). `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. (default: []).
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)

View file

@ -15,9 +15,13 @@
using json = nlohmann::json; using json = nlohmann::json;
inline static json oaicompat_completion_params_parse( inline static json oaicompat_completion_params_parse(
const json &body /* openai api json semantics */) const json &body, /* openai api json semantics */
const std::string &chat_template)
{ {
json llama_params; json llama_params;
std::string formatted_prompt = chat_template == "chatml"
? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
: format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
llama_params["__oaicompat"] = true; llama_params["__oaicompat"] = true;
@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create // https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown")); llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt' llama_params["prompt"] = formatted_prompt;
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);

View file

@ -36,6 +36,7 @@ struct server_params
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys; std::vector<std::string> api_keys;
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
std::string chat_template = "chatml";
int32_t port = 8080; int32_t port = 8080;
int32_t read_timeout = 600; int32_t read_timeout = 600;
int32_t write_timeout = 600; int32_t write_timeout = 600;
@ -625,18 +626,36 @@ struct llama_server_context
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(model);
for (const auto &el : *logit_bias) for (const auto &el : *logit_bias)
{ {
if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) if (el.is_array() && el.size() == 2)
{ {
llama_token tok = el[0].get<llama_token>(); float bias;
if (tok >= 0 && tok < n_vocab) if (el[1].is_number())
{ {
if (el[1].is_number()) bias = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
bias = -INFINITY;
}
else
{
continue;
}
if (el[0].is_number_integer())
{
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab)
{ {
slot->sparams.logit_bias[tok] = el[1].get<float>(); slot->sparams.logit_bias[tok] = bias;
} }
else if (el[1].is_boolean() && !el[1].get<bool>()) }
else if (el[0].is_string())
{
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks)
{ {
slot->sparams.logit_bias[tok] = -INFINITY; slot->sparams.logit_bias[tok] = bias;
} }
} }
} }
@ -1592,10 +1611,6 @@ struct llama_server_context
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
} }
LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
slot.cache_tokens = prompt_tokens; slot.cache_tokens = prompt_tokens;
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
@ -1609,6 +1624,10 @@ struct llama_server_context
} }
} }
LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
LOG_VERBOSE("prompt ingested", { LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past}, {"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@ -1862,6 +1881,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf(" --chat-template FORMAT_NAME");
printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
printf("\n"); printf("\n");
} }
@ -2301,6 +2322,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
log_set_target(stdout); log_set_target(stdout);
LOG_INFO("logging to file is disabled.", {}); LOG_INFO("logging to file is disabled.", {});
} }
else if (arg == "--chat-template")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
if (value != "chatml" && value != "llama2") {
fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
invalid_param = true;
break;
}
sparams.chat_template = value;
}
else if (arg == "--override-kv") else if (arg == "--override-kv")
{ {
if (++i >= argc) { if (++i >= argc) {
@ -2754,13 +2790,13 @@ int main(int argc, char **argv)
// TODO: add mount point without "/v1" prefix -- how? // TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) { if (!validate_api_key(req, res)) {
return; return;
} }
json data = oaicompat_completion_params_parse(json::parse(req.body)); json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
const int task_id = llama.queue_tasks.get_new_id(); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);

View file

@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v
: default_value; : default_value;
} }
inline std::string format_llama2(std::vector<json> messages)
{
std::ostringstream output;
bool is_inside_turn = false;
for (auto it = messages.begin(); it != messages.end(); ++it) {
if (!is_inside_turn) {
output << "[INST] ";
}
std::string role = json_value(*it, "role", std::string("user"));
std::string content = json_value(*it, "content", std::string(""));
if (role == "system") {
output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
is_inside_turn = true;
} else if (role == "user") {
output << content << " [/INST]";
is_inside_turn = true;
} else {
output << " " << content << " </s>";
is_inside_turn = false;
}
}
LOG_VERBOSE("format_llama2", {{"text", output.str()}});
return output.str();
}
inline std::string format_chatml(std::vector<json> messages) inline std::string format_chatml(std::vector<json> messages)
{ {
std::ostringstream chatml_msgs; std::ostringstream chatml_msgs;
@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)
chatml_msgs << "<|im_start|>assistant" << '\n'; chatml_msgs << "<|im_start|>assistant" << '\n';
LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
return chatml_msgs.str(); return chatml_msgs.str();
} }

View file

@ -1,5 +1,6 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h"
#include "common.h" #include "common.h"
#include "train.h" #include "train.h"
#include "llama.h" #include "llama.h"
@ -19,8 +20,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
static const size_t tensor_alignment = 32;
struct my_llama_hparams { struct my_llama_hparams {
uint32_t n_vocab = 32000; uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; uint32_t n_ctx = 512;
@ -58,7 +57,7 @@ struct my_llama_layer {
struct my_llama_model { struct my_llama_model {
struct ggml_context * ctx = NULL; struct ggml_context * ctx = NULL;
std::vector<uint8_t> data; ggml_backend_buffer_t data = NULL;
my_llama_hparams hparams; my_llama_hparams hparams;
@ -147,39 +146,6 @@ static void set_param_model(struct my_llama_model * model) {
} }
} }
static void alloc_model(struct ggml_allocr * alloc, struct my_llama_model * model) {
ggml_allocr_alloc(alloc, model->tok_embeddings);
ggml_allocr_alloc(alloc, model->norm);
ggml_allocr_alloc(alloc, model->output);
for (uint32_t i = 0; i < model->layers.size(); ++i) {
auto & layer = model->layers[i];
ggml_allocr_alloc(alloc, layer.attention_norm);
ggml_allocr_alloc(alloc, layer.wq);
ggml_allocr_alloc(alloc, layer.wk);
ggml_allocr_alloc(alloc, layer.wv);
ggml_allocr_alloc(alloc, layer.wo);
ggml_allocr_alloc(alloc, layer.ffn_norm);
ggml_allocr_alloc(alloc, layer.w1);
ggml_allocr_alloc(alloc, layer.w2);
ggml_allocr_alloc(alloc, layer.w3);
}
ggml_allocr_alloc(alloc, model->tok_embeddings->grad);
ggml_allocr_alloc(alloc, model->norm->grad);
ggml_allocr_alloc(alloc, model->output->grad);
for (uint32_t i = 0; i < model->layers.size(); ++i) {
auto & layer = model->layers[i];
ggml_allocr_alloc(alloc, layer.attention_norm->grad);
ggml_allocr_alloc(alloc, layer.wq->grad);
ggml_allocr_alloc(alloc, layer.wk->grad);
ggml_allocr_alloc(alloc, layer.wv->grad);
ggml_allocr_alloc(alloc, layer.wo->grad);
ggml_allocr_alloc(alloc, layer.ffn_norm->grad);
ggml_allocr_alloc(alloc, layer.w1->grad);
ggml_allocr_alloc(alloc, layer.w2->grad);
ggml_allocr_alloc(alloc, layer.w3->grad);
}
}
static void init_model(struct my_llama_model * model) { static void init_model(struct my_llama_model * model) {
const auto & hparams = model->hparams; const auto & hparams = model->hparams;
@ -252,17 +218,8 @@ static void init_model(struct my_llama_model * model) {
set_param_model(model); set_param_model(model);
// measure data size
size_t size = 0;
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
}
// allocate data // allocate data
struct ggml_allocr * alloc = NULL; model->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
model->data.resize(size + tensor_alignment);
alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
alloc_model(alloc, model);
} }
static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) { static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@ -297,7 +254,7 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
static struct ggml_tensor * llama_build_train_graphs( static struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model, struct my_llama_model * model,
struct ggml_allocr * alloc, ggml_gallocr_t alloc,
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
struct ggml_cgraph * gb, struct ggml_cgraph * gb,
@ -308,7 +265,8 @@ static struct ggml_tensor * llama_build_train_graphs(
const int n_tokens, const int n_tokens,
const int n_batch, const int n_batch,
const bool enable_flash_attn, const bool enable_flash_attn,
const bool enable_checkpointing) { const bool enable_checkpointing,
const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, }); ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0; const int n_past = 0;
@ -334,13 +292,7 @@ static struct ggml_tensor * llama_build_train_graphs(
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
ggml_allocr_alloc(alloc, KQ_pos); ggml_set_input(KQ_pos);
if (!ggml_allocr_is_measure(alloc)) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
// rope has so much parameters that we make a custom function for it // rope has so much parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale] auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
@ -448,21 +400,31 @@ static struct ggml_tensor * llama_build_train_graphs(
// KQ_pos // KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f)); ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
ggml_set_input(t36->grad);
ggml_allocr_alloc(alloc, t36->grad);
// allocating checkpoints in one block to reduce memory fragmentation // allocating checkpoints in one block to reduce memory fragmentation
// note: they will be freed in reverse order // note: they will be freed in reverse order
for (int i = 0; i < (int) checkpoints.size(); ++i) { for (int i = 0; i < (int) checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) { if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
ggml_allocr_alloc(alloc, checkpoints[i]); ggml_set_input(checkpoints[i]);
} }
} }
//int n_leafs_after = gb->n_leafs; //int n_leafs_after = gb->n_leafs;
//int n_nodes_after = gb->n_nodes; //int n_nodes_after = gb->n_nodes;
if (measure_only) {
// FIXME: will still allocate
ggml_gallocr_reserve(alloc, gb);
} else {
ggml_gallocr_alloc_graph(alloc, gb);
ggml_allocr_alloc_graph(alloc, gb); if (!measure_only) {
int * data = (int *) KQ_pos->data;
for (int i = 0; i < N; ++i) {
data[i] = n_past + i;
}
}
}
// remove the additional nodes and leafs // remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) { for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
@ -1046,7 +1008,7 @@ int main(int argc, char ** argv) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples); printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens); printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs); printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f)); printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)), (float) (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)) / (1024.0f*1024.0f));
if (params.only_write_model) { if (params.only_write_model) {
save_train_files_data save_data; save_train_files_data save_data;
@ -1073,11 +1035,6 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch; int n_batch = params.common.n_batch;
std::vector<uint8_t> mem_input_data;
std::vector<uint8_t> mem_compute_data;
ggml_allocr * alloc = NULL;
// context for input tensors without their data // context for input tensors without their data
struct ggml_init_params ctx_input_params = { struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size ggml_tensor_overhead() * 2, // mem_size
@ -1091,16 +1048,10 @@ int main(int argc, char ** argv) {
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// measure required memory for input tensors // measure required memory for input tensors
size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
tensor_alignment;
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// allocate input tensors // allocate input tensors
mem_input_data.resize(max_input_size); ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); size_t max_input_size = ggml_backend_buffer_get_size(input_data);
ggml_allocr_alloc(alloc, tokens_input); printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
ggml_allocr_alloc(alloc, target_probs);
// context for compute tensors without their data // context for compute tensors without their data
const size_t estimated_compute_size_wo_data = ( const size_t estimated_compute_size_wo_data = (
@ -1127,7 +1078,7 @@ int main(int argc, char ** argv) {
// find best evaluation order // find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) { for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params); ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment); ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order; gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@ -1140,9 +1091,10 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.common.use_flash, params.common.use_flash,
params.common.use_checkpointing params.common.use_checkpointing,
true
); );
size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment; size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) { if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size; best_compute_size = max_compute_size;
best_order = gf->order; best_order = gf->order;
@ -1157,9 +1109,8 @@ int main(int argc, char ** argv) {
"invalid"); "invalid");
// allocate compute tensors // allocate compute tensors
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params); ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment); ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order; gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true); gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
@ -1172,7 +1123,8 @@ int main(int argc, char ** argv) {
&logits, tokens_input, target_probs, &logits, tokens_input, target_probs,
n_tokens, n_batch, n_tokens, n_batch,
params.common.use_flash, params.common.use_flash,
params.common.use_checkpointing params.common.use_checkpointing,
false
); );
std::vector<llama_token> train_tokens; std::vector<llama_token> train_tokens;

6
flake.lock generated
View file

@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1706732774, "lastModified": 1707268954,
"narHash": "sha256-hqJlyJk4MRpcItGYMF+3uHe8HvxNETWvlGtLuVpqLU0=", "narHash": "sha256-2en1kvde3cJVc3ZnTy8QeD2oKcseLFjYPLKhIGDanQ0=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "b8b232ae7b8b144397fdb12d20f592e5e7c1a64d", "rev": "f8e2ebd66d097614d51a56a755450d4ae1632df1",
"type": "github" "type": "github"
}, },
"original": { "original": {

File diff suppressed because it is too large Load diff

View file

@ -6,88 +6,62 @@
extern "C" { extern "C" {
#endif #endif
struct ggml_backend; typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
struct ggml_backend_buffer; typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
struct ggml_backend_buffer_type; typedef struct ggml_backend * ggml_backend_t;
//
// Legacy API
//
typedef struct ggml_allocr * ggml_allocr_t;
// initialize allocator for use with CPU backend only
GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
// initialize allocator for use with ggml-backend
GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph are optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
GGML_API void ggml_allocr_free (ggml_allocr_t alloc);
GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc);
GGML_API void ggml_allocr_reset (ggml_allocr_t alloc);
GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc);
GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
//
// ggml-backend v2 API
//
// Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
// Tensor allocator // Tensor allocator
typedef struct ggml_tallocr * ggml_tallocr_t; typedef struct ggml_tallocr * ggml_tallocr_t;
GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment); GGML_API ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer);
GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment); GGML_API void ggml_tallocr_free(ggml_tallocr_t talloc);
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size); GGML_API void ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor);
GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft);
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc);
// Graph allocator // Graph allocator
/*
Example usage:
ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type());
// optional: create a worst-case graph and reserve the buffers to avoid reallocations
ggml_gallocr_reserve(galloc, build_graph(max_batch));
// allocate the graph
struct ggml_cgraph * graph = build_graph(batch);
ggml_gallocr_alloc_graph(galloc, graph);
printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
// evaluate the graph
ggml_backend_graph_compute(backend, graph);
*/
// special tensor flags for use with the graph allocator:
// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
// ggml_set_output(): output tensors are never freed and never overwritten
typedef struct ggml_gallocr * ggml_gallocr_t; typedef struct ggml_gallocr * ggml_gallocr_t;
GGML_API ggml_gallocr_t ggml_gallocr_new(void); GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n); // pre-allocate buffers from a measure graph - does not allocate or modify the graph
GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph); // call with a worst-case graph to avoid buffer reallocations
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
// returns false if the buffer allocation failed
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
GGML_API bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids);
// Allocate tensors from the allocators given by the hash table // automatic reallocation if the topology changes when using a single buffer
GGML_API void ggml_gallocr_alloc_graph_n( // returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
ggml_gallocr_t galloc, GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
struct ggml_cgraph * graph,
struct ggml_hash_set hash_set,
ggml_tallocr_t * hash_node_talloc);
GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
// Utils // Utils
// Create a buffer and allocate all the tensors in a ggml_context // Create a buffer and allocate all the tensors in a ggml_context
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
#ifdef __cplusplus #ifdef __cplusplus
} }

File diff suppressed because it is too large Load diff

View file

@ -83,8 +83,9 @@ extern "C" {
GGML_API ggml_backend_t ggml_backend_cpu_init(void); GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
// Create a backend buffer from an existing pointer // Create a backend buffer from an existing pointer
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
@ -129,11 +130,7 @@ extern "C" {
// in build_graph: // in build_graph:
build_graph(...) { build_graph(...) {
// allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) // manually assign nodes to a backend (optional, should not be needed in most cases)
alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
ggml_allocr_alloc(alloc_cpu, tensor);
// manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
struct ggml_tensor * node = ggml_mul_mat(ctx, ...); struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
ggml_backend_sched_set_node_backend(sched, node, backend_gpu); ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
} }
@ -163,20 +160,19 @@ extern "C" {
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size); GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph // Initialize backend buffers from a measure graph
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
// Get the number of splits of the last graph // Get the number of splits of the last graph
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Allocate and compute graph on the backend scheduler // Allocate and compute graph on the backend scheduler
GGML_API void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
// Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs // Reset all assignments and allocators - must be called before changing the node backends
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute // Set a callback to be called for each resulting node during graph compute

View file

@ -150,8 +150,8 @@
#define CUDA_USE_TENSOR_CORES #define CUDA_USE_TENSOR_CORES
#endif #endif
// max batch size to use MMQ kernels when tensor cores are available #define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE 32 #define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
#if defined(GGML_USE_HIPBLAS) #if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300 #define __CUDA_ARCH__ 1300
@ -5310,51 +5310,59 @@ template <bool need_check> static __global__ void
#endif // __CUDA_ARCH__ >= CC_VOLTA #endif // __CUDA_ARCH__ >= CC_VOLTA
} }
#define MMVQ_NWARPS_NVIDIA 4 template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#define MMVQ_NWARPS_AMD_RDNA2 1
#define MMVQ_NWARPS_AMD_OLD 4
template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants // tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) { const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par; #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row = blockIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_col_y = nrows_y / QK8_1; constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
// partial sum for each thread // partial sum for each thread
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f}; float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
const block_q_t * x = (const block_q_t *) vx; const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy; const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) { for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int ibx = row*blocks_per_row_x + i; // x block index const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx // x block quant index when casting the quants to int
const int kqs = vdr * (tid % (qi/vdr));
const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_y; ++j) { for (int j = 0; j < ncols_y; ++j) {
tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs); #pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
}
} }
} }
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE]; __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
if (threadIdx.y > 0) { if (threadIdx.y > 0) {
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_y; ++j) { for (int j = 0; j < ncols_y; ++j) {
tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j]; #pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
}
} }
} }
__syncthreads(); __syncthreads();
@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_y; ++j) { for (int j = 0; j < ncols_y; ++j) {
#pragma unroll #pragma unroll
for (int i = 0; i < nwarps-1; ++i) { for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j] += tmp_shared[i][j][threadIdx.x]; #pragma unroll
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
} }
tmp[j] = warp_reduce_sum(tmp[j]);
if (threadIdx.x == 0) { if (threadIdx.x < rows_per_cuda_block) {
dst[j*nrows_dst + row] = tmp[j]; dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
} }
} }
} }
@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda(
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_x % qk == 0);
GGML_ASSERT(ncols_y <= 4); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
int nwarps; int64_t nwarps = 1;
if (g_device_caps[id].cc >= CC_OFFSET_AMD) { int64_t rows_per_cuda_block = 1;
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
} else {
nwarps = MMVQ_NWARPS_NVIDIA;
}
const dim3 block_nums(nrows_x, 1, 1); if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
rows_per_cuda_block = 1;
break;
case 2:
case 3:
case 4:
nwarps = 4;
rows_per_cuda_block = 2;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
rows_per_cuda_block = 2;
break;
default:
GGML_ASSERT(false);
break;
}
}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1);
switch (nwarps) { switch (ncols_y) {
case 1: switch(ncols_y) { case 1:
case 1: mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break;
break; case 2:
case 2: mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break;
break; case 3:
case 3: mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break;
break; case 4:
case 4: mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break;
break; case 5:
default: mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
GGML_ASSERT(false); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break; break;
} break; case 6:
case 4: switch(ncols_y) { mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
case 1: <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot> break;
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); case 7:
break; mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
case 2: <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot> break;
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); case 8:
break; mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
case 3: <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot> break;
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 4:
mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
default:
GGML_ASSERT(false);
break;
} break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
break; break;
@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs(
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
} }
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_transposed(src1));
@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
int64_t min_compute_capability = INT_MAX; int64_t min_compute_capability = INT_MAX;
bool any_pascal_with_slow_fp16 = false;
if (split) { if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
auto & tensor_split = buft_ctx->tensor_split; auto & tensor_split = buft_ctx->tensor_split;
for (int id = 0; id < g_device_count; ++id) { for (int id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) { // skip devices that are not going to do any work:
if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
continue;
}
if (min_compute_capability > g_device_caps[id].cc) {
min_compute_capability = g_device_caps[id].cc; min_compute_capability = g_device_caps[id].cc;
} }
if (g_device_caps[id].cc == 610) {
any_pascal_with_slow_fp16 = true;
}
} }
} else { } else {
min_compute_capability = g_device_caps[g_main_device].cc; min_compute_capability = g_device_caps[g_main_device].cc;
any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610;
} }
// check data types and tensor shapes for custom matrix multiplication kernels:
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
bool use_mul_mat_q = ggml_is_quantized(src0->type);
#ifdef CUDA_USE_TENSOR_CORES #ifdef CUDA_USE_TENSOR_CORES
use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
#endif // CUDA_USE_TENSOR_CORES #endif // CUDA_USE_TENSOR_CORES
#else #else
const bool fp16_performance_good = min_compute_capability >= CC_VOLTA; // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
// mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A;
#ifdef CUDA_USE_TENSOR_CORES #ifdef CUDA_USE_TENSOR_CORES
// when tensor cores are available, use them for large batch size // when tensor cores are available, use them for large batch size
// ref: https://github.com/ggerganov/llama.cpp/pull/3776 // ref: https://github.com/ggerganov/llama.cpp/pull/3776
use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE); use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
#endif // CUDA_USE_TENSOR_CORES #endif // CUDA_USE_TENSOR_CORES
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type); // if mmvq is available it's a better choice than dmmv:
#ifndef GGML_CUDA_FORCE_DMMV
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
#endif // GGML_CUDA_FORCE_DMMV
// debug helpers // debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
@ -9943,33 +9994,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
ggml_cuda_mul_mat_vec_nc(src0, src1, dst); ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// KQ + KQV multi-batch // KQ + KQV multi-batch
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) { } else if (use_dequantize_mul_mat_vec) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { } else if (use_mul_mat_vec_q) {
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
#ifdef GGML_CUDA_FORCE_DMMV } else if (use_mul_mat_q) {
const bool use_mul_mat_vec_q = false; ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
#else
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
#endif // GGML_CUDA_FORCE_DMMV
if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
}
} else {
if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
} else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
}
}
} else { } else {
GGML_ASSERT(false); ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} }
} }

View file

@ -687,6 +687,7 @@ static bool ggml_metal_graph_compute(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf) {
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
edesc.dispatchType = MTLDispatchTypeSerial; edesc.dispatchType = MTLDispatchTypeSerial;
@ -2272,6 +2273,7 @@ static bool ggml_metal_graph_compute(
[[MTLCaptureManager sharedCaptureManager] stopCapture]; [[MTLCaptureManager sharedCaptureManager] stopCapture];
} }
}
return true; return true;
} }

View file

@ -49,6 +49,8 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define UNUSED GGML_UNUSED
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@ -268,6 +270,17 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
#ifdef _MSC_VER
#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
#else
#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
#endif
#if !defined(__aarch64__) #if !defined(__aarch64__)
// 64-bit compatibility // 64-bit compatibility
@ -3666,15 +3679,92 @@ static inline __m128i get_scale_shuffle(int i) {
} }
#endif #endif
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
assert(n % qk == 0); assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_0 * restrict x = vx; const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_0 * restrict vx0 = vx;
const block_q4_0 * restrict vx1 = vx + bx;
const block_q8_0 * restrict vy0 = vy;
const block_q8_0 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q4_0 * restrict b_x0 = &vx0[i];
const block_q4_0 * restrict b_x1 = &vx1[i];
const block_q8_0 * restrict b_y0 = &vy0[i];
const block_q8_0 * restrict b_y1 = &vy1[i];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -3956,15 +4046,93 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
#endif #endif
} }
void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
assert(n % qk == 0); assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_1 * restrict x = vx; const block_q4_1 * restrict x = vx;
const block_q8_1 * restrict y = vy; const block_q8_1 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_1 * restrict vx0 = vx;
const block_q4_1 * restrict vx1 = vx + bx;
const block_q8_1 * restrict vy0 = vy;
const block_q8_1 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t summs0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q4_1 * restrict b_x0 = &vx0[i];
const block_q4_1 * restrict b_x1 = &vx1[i];
const block_q8_1 * restrict b_y0 = &vy0[i];
const block_q8_1 * restrict b_y1 = &vy1[i];
float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
summs0 += summs_t;
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
// 4-bit -> 8-bit
const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
// mmla into int32x4_t
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
sumv2 = sumv2 + summs0;
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
// TODO: add WASM SIMD // TODO: add WASM SIMD
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -4096,12 +4264,17 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
#endif #endif
} }
void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
assert(n % qk == 0); assert(n % qk == 0);
assert(qk == QK5_0); assert(qk == QK5_0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_0 * restrict x = vx; const block_q5_0 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
@ -4382,12 +4555,17 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
#endif #endif
} }
void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
assert(n % qk == 0); assert(n % qk == 0);
assert(qk == QK5_1); assert(qk == QK5_1);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_1 * restrict x = vx; const block_q5_1 * restrict x = vx;
const block_q8_1 * restrict y = vy; const block_q8_1 * restrict y = vy;
@ -4681,15 +4859,79 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
#endif #endif
} }
void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
assert(n % qk == 0); assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
assert((nrc == 2) || (nrc == 1));
#else
assert(nrc == 1);
#endif
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q8_0 * restrict x = vx; const block_q8_0 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
#if defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q8_0 * restrict vx0 = vx;
const block_q8_0 * restrict vx1 = vx + bx;
const block_q8_0 * restrict vy0 = vy;
const block_q8_0 * restrict vy1 = vy + by;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
for (int i = 0; i < nb; i++) {
const block_q8_0 * restrict b_x0 = &vx0[i];
const block_q8_0 * restrict b_y0 = &vy0[i];
const block_q8_0 * restrict b_x1 = &vx1[i];
const block_q8_0 * restrict b_y1 = &vy1[i];
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
return;
}
#endif
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -4784,7 +5026,12 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
} }
#if QK_K == 256 #if QK_K == 256
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * restrict x = vx; const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -5160,7 +5407,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * restrict x = vx; const block_q2_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -5418,8 +5670,13 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif #endif
#if QK_K == 256 #if QK_K == 256
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const uint32_t kmask1 = 0x03030303; const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f; const uint32_t kmask2 = 0x0f0f0f0f;
@ -5938,8 +6195,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q3_K * restrict x = vx; const block_q3_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -6281,8 +6543,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
#endif #endif
#if QK_K == 256 #if QK_K == 256
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_K * restrict x = vx; const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -6637,8 +6904,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
#endif #endif
} }
#else #else
void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q4_K * restrict x = vx; const block_q4_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -6880,8 +7152,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
#endif #endif
#if QK_K == 256 #if QK_K == 256
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_K * restrict x = vx; const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -7300,8 +7577,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q5_K * restrict x = vx; const block_q5_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -7566,8 +7848,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#if QK_K == 256 #if QK_K == 256
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q6_K * restrict x = vx; const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -7998,8 +8285,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q6_K * restrict x = vx; const block_q6_K * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -8328,8 +8620,13 @@ static const int8_t keven_signs_q2xs[1024] = {
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
}; };
void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq2_xxs * restrict x = vx; const block_iq2_xxs * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -8451,8 +8748,13 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res
#endif #endif
} }
void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq2_xs * restrict x = vx; const block_iq2_xs * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -8671,8 +8973,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
} }
// TODO // TODO
void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq3_xxs * restrict x = vx; const block_iq3_xxs * restrict x = vx;
const block_q8_K * restrict y = vy; const block_q8_K * restrict y = vy;
@ -8698,10 +9005,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_vld1q_s8_x4(q8); q8 += 64; q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
q3 += 16; q3 += 16;
q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));

View file

@ -245,20 +245,20 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
// Dot product // Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy); void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// //
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")

View file

@ -11578,11 +11578,8 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
} }
char * dst_ptr = (char *) dst; char * dst_ptr = (char *) dst;
const int64_t ne0 = src->ne[0]; GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
const int64_t nb0 = src->nb[0]; GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
const int64_t nb1 = src->nb[1];
const int64_t nb2 = src->nb[2];
const int64_t nb3 = src->nb[3];
const enum ggml_type type = src->type; const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type); const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type); const int64_t bs = ggml_blck_size(type);
@ -12426,9 +12423,7 @@ inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0); const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
@ -12758,15 +12753,9 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
ggml_sycl_op_mul_mat_t op, ggml_sycl_op_mul_mat_t op,
const bool convert_src1_to_q8_1) try { const bool convert_src1_to_q8_1) try {
const int64_t ne00 = src0->ne[0]; GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne10 = src1->ne[0]; GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int64_t nrows1 = ggml_nrows(src1); const int64_t nrows1 = ggml_nrows(src1);
GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne03 == ne13);
@ -13337,23 +13326,13 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t nb01 = src0->nb[1]; GGML_TENSOR_LOCALS(int64_t, nb0, src0, nb);
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
const int64_t ne10 = src1->ne[0]; GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int64_t nb11 = src1->nb[1]; GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
const int64_t ne1 = ggml_nelements(src1); const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst); const int64_t ne = ggml_nelements(dst);
@ -13655,23 +13634,15 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); GGML_TENSOR_LOCALS(int64_t, ne0, src00, ne);
const int64_t ne01 = src00->ne[1];
const int64_t ne02 = src00->ne[2];
const int64_t ne03 = src00->ne[3];
//const int64_t nb01 = src00->nb[1]; //const int64_t nb01 = src00->nb[1];
const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02); GGML_TENSOR_LOCALS(int64_t, nb0, src00, nb);
const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
const int64_t ne10 = src1->ne[0]; GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
//const int64_t nb11 = src1->nb[1]; //const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
const int64_t ne1 = ggml_nelements(src1); const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst); const int64_t ne = ggml_nelements(dst);
@ -13940,25 +13911,7 @@ static void ggml_sycl_cpy(const ggml_tensor *src0, const ggml_tensor *src1,
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS;
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nb00 = src0->nb[0];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t nb03 = src0->nb[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
const int64_t nb13 = src1->nb[3];
SYCL_CHECK(ggml_sycl_set_device(g_main_device)); SYCL_CHECK(ggml_sycl_set_device(g_main_device));
dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0]; dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];

View file

@ -27,6 +27,7 @@
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
#define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_AMD 0x1002
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de #define VK_VENDOR_ID_NVIDIA 0x10de
@ -744,6 +745,8 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
} }
if (memory_type_index >= mem_props.memoryTypeCount) { if (memory_type_index >= mem_props.memoryTypeCount) {
ctx->device.lock()->device.destroyBuffer(buf->buffer);
buf->size = 0;
throw vk::OutOfDeviceMemoryError("No suitable memory type found"); throw vk::OutOfDeviceMemoryError("No suitable memory type found");
} }
@ -2032,18 +2035,100 @@ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ct
return ctx->pipeline_matmul_f32_aligned_l.align; return ctx->pipeline_matmul_f32_aligned_l.align;
} }
static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) { static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
#endif
if (bit16_x && bit16_y) { if (bit16_x && bit16_y) {
if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) { if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl; std::cerr << " S" << std::endl;
#endif #endif
return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s; return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
} }
if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) { #ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
}
if (bit16_x && !bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
}
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
}
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
}
static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl;
#endif
if (bit16_x && bit16_y) {
return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
}
if (bit16_x && !bit16_y) {
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
}
return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
}
static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
if (bit16_x && bit16_y) {
return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
}
if (bit16_x && !bit16_y) {
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
}
if (!bit16_x && bit16_y) {
GGML_ASSERT(false);
}
return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
}
static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
#endif
switch (ctx->device.lock()->vendor_id) {
case VK_VENDOR_ID_AMD:
return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
case VK_VENDOR_ID_APPLE:
return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
case VK_VENDOR_ID_INTEL:
return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
}
if (bit16_x && bit16_y) {
if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl;
#endif
return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
}
if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl; std::cerr << " M" << std::endl;
#endif #endif
@ -2055,13 +2140,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l; return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
} }
if (bit16_x && !bit16_y) { if (bit16_x && !bit16_y) {
if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) { if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl; std::cerr << " S" << std::endl;
#endif #endif
return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s; return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
} }
if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) { if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl; std::cerr << " M" << std::endl;
#endif #endif
@ -2076,13 +2161,13 @@ static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
GGML_ASSERT(false); GGML_ASSERT(false);
} }
if (ctx->device.lock()->vendor_id == VK_VENDOR_ID_INTEL || m <= 32 || n <= 32) { if (m <= 32 || n <= 32) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " S" << std::endl; std::cerr << " S" << std::endl;
#endif #endif
return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s; return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
} }
if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) { if (m <= 64 || n <= 64) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << " M" << std::endl; std::cerr << " M" << std::endl;
#endif #endif
@ -3875,7 +3960,7 @@ static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph
static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_ctx->preallocate_buffers_graph(" << node << ")" << std::endl; std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl;
#endif #endif
const bool any_on_device = node->backend == GGML_BACKEND_GPU const bool any_on_device = node->backend == GGML_BACKEND_GPU
|| (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
@ -3994,8 +4079,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
return; return;
} }
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_ctx->preallocate_buffers()" << std::endl; std::cerr << "ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << ")" << std::endl;
std::cerr << "qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << std::endl;
#endif #endif
#if defined(GGML_VULKAN_RUN_TESTS) #if defined(GGML_VULKAN_RUN_TESTS)
ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached); ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);

194
ggml.c
View file

@ -428,8 +428,8 @@ int64_t ggml_cycles_per_ms(void) {
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = { [GGML_TYPE_I8] = {
@ -457,6 +457,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = false, .is_quantized = false,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32, .vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
}, },
[GGML_TYPE_F16] = { [GGML_TYPE_F16] = {
.type_name = "f16", .type_name = "f16",
@ -468,6 +469,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
.vec_dot_type = GGML_TYPE_F16, .vec_dot_type = GGML_TYPE_F16,
.nrows = 1,
}, },
[GGML_TYPE_Q4_0] = { [GGML_TYPE_Q4_0] = {
.type_name = "q4_0", .type_name = "q4_0",
@ -479,6 +481,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
.vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot = ggml_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
}, },
[GGML_TYPE_Q4_1] = { [GGML_TYPE_Q4_1] = {
.type_name = "q4_1", .type_name = "q4_1",
@ -490,6 +497,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
.vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
}, },
[4] = { // GGML_TYPE_Q4_2 [4] = { // GGML_TYPE_Q4_2
.type_name = "DEPRECATED", .type_name = "DEPRECATED",
@ -501,6 +513,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = NULL, .from_float_reference = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT, .vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1,
}, },
[5] = { // GGML_TYPE_Q4_3 [5] = { // GGML_TYPE_Q4_3
.type_name = "DEPRECATED", .type_name = "DEPRECATED",
@ -512,6 +525,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = NULL, .from_float_reference = NULL,
.vec_dot = NULL, .vec_dot = NULL,
.vec_dot_type = GGML_TYPE_COUNT, .vec_dot_type = GGML_TYPE_COUNT,
.nrows = 1,
}, },
[GGML_TYPE_Q5_0] = { [GGML_TYPE_Q5_0] = {
.type_name = "q5_0", .type_name = "q5_0",
@ -523,6 +537,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
.vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot = ggml_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
}, },
[GGML_TYPE_Q5_1] = { [GGML_TYPE_Q5_1] = {
.type_name = "q5_1", .type_name = "q5_1",
@ -534,6 +549,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
.vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
}, },
[GGML_TYPE_Q8_0] = { [GGML_TYPE_Q8_0] = {
.type_name = "q8_0", .type_name = "q8_0",
@ -545,6 +561,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0, .vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
}, },
[GGML_TYPE_Q8_1] = { [GGML_TYPE_Q8_1] = {
.type_name = "q8_1", .type_name = "q8_1",
@ -554,6 +575,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q8_1, .from_float = quantize_row_q8_1,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
}, },
[GGML_TYPE_Q2_K] = { [GGML_TYPE_Q2_K] = {
.type_name = "q2_K", .type_name = "q2_K",
@ -565,6 +587,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
.vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_Q3_K] = { [GGML_TYPE_Q3_K] = {
.type_name = "q3_K", .type_name = "q3_K",
@ -576,6 +599,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
.vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_Q4_K] = { [GGML_TYPE_Q4_K] = {
.type_name = "q4_K", .type_name = "q4_K",
@ -587,6 +611,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
.vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_Q5_K] = { [GGML_TYPE_Q5_K] = {
.type_name = "q5_K", .type_name = "q5_K",
@ -598,6 +623,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
.vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_Q6_K] = { [GGML_TYPE_Q6_K] = {
.type_name = "q6_K", .type_name = "q6_K",
@ -609,6 +635,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
.vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_IQ2_XXS] = { [GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs", .type_name = "iq2_xxs",
@ -620,6 +647,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = NULL, .from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_IQ2_XS] = { [GGML_TYPE_IQ2_XS] = {
.type_name = "iq2_xs", .type_name = "iq2_xs",
@ -631,6 +659,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = NULL, .from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_IQ3_XXS] = { [GGML_TYPE_IQ3_XXS] = {
.type_name = "iq3_xxs", .type_name = "iq3_xxs",
@ -642,6 +671,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
}, },
[GGML_TYPE_Q8_K] = { [GGML_TYPE_Q8_K] = {
.type_name = "q8_K", .type_name = "q8_K",
@ -1212,7 +1242,13 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
#ifdef GGML_SIMD #ifdef GGML_SIMD
float sumf = 0.0f; float sumf = 0.0f;
const int np = (n & ~(GGML_F32_STEP - 1)); const int np = (n & ~(GGML_F32_STEP - 1));
@ -1249,7 +1285,13 @@ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * rest
*s = sumf; *s = sumf;
} }
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
ggml_float sumf = 0.0; ggml_float sumf = 0.0;
#if defined(GGML_SIMD) #if defined(GGML_SIMD)
@ -1455,7 +1497,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif #endif
} }
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
@ -2631,7 +2673,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.nb =*/ { 0, 0, 0, 0 }, /*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_OP_NONE, /*.op =*/ GGML_OP_NONE,
/*.op_params =*/ { 0 }, /*.op_params =*/ { 0 },
/*.is_param =*/ false, /*.flags =*/ 0,
/*.grad =*/ NULL, /*.grad =*/ NULL,
/*.src =*/ { NULL }, /*.src =*/ { NULL },
/*.perf_runs =*/ 0, /*.perf_runs =*/ 0,
@ -6533,7 +6575,7 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
void ggml_set_param( void ggml_set_param(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * tensor) { struct ggml_tensor * tensor) {
tensor->is_param = true; tensor->flags |= GGML_TENSOR_FLAG_PARAM;
GGML_ASSERT(tensor->grad == NULL); GGML_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_dup_tensor(ctx, tensor); tensor->grad = ggml_dup_tensor(ctx, tensor);
@ -10016,6 +10058,7 @@ static void ggml_compute_forward_mul_mat(
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11); GGML_ASSERT(ne1 == ne11);
@ -10183,12 +10226,23 @@ static void ggml_compute_forward_mul_mat(
const int64_t blck_0 = 16; const int64_t blck_0 = 16;
const int64_t blck_1 = 16; const int64_t blck_1 = 16;
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
int64_t nrc = vec_dot_num_rows;
// TODO: currently the mmla kernels support only even numbered rows/cols.
// this check can be removed once they are extended to support odd numbered rows/cols too
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
nrc = 1;
}
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
// attempt to reduce false-sharing (does not seem to make a difference) // attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16]; // 16 * 2, accounting for mmla kernels
float tmp[32];
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
const int64_t i13 = (ir1/(ne12*ne1)); const int64_t i13 = (ir1/(ne12*ne1));
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1; const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1); const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
@ -10211,17 +10265,19 @@ static void ggml_compute_forward_mul_mat(
(src1_cont || src1->type != vec_dot_type (src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13)); : (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//} //}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
}
for (int cn = 0; cn < nrc; ++cn) {
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
} }
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
} }
} }
} }
@ -10410,7 +10466,7 @@ static void ggml_compute_forward_mul_mat_id(
//} //}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
} }
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
} }
@ -11592,7 +11648,7 @@ static void ggml_compute_forward_soft_max_back_f32(
// linear runtime, no additional memory // linear runtime, no additional memory
float dot_y_dy = 0; float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy); ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy); ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y); ggml_vec_mul_f32 (nc, dx, dx, y);
@ -12393,9 +12449,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
const int i1n = i10*ne11; const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
float v = 0; float v = 0;
ggml_vec_dot_f16(ne02, &v, ggml_vec_dot_f16(ne02, &v, 0,
(ggml_fp16_t *) wdata_src + i1n, (ggml_fp16_t *) wdata_src + i1n, 0,
(ggml_fp16_t *) wdata_kernel + i00*ne02); (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
dst_data[i10*s0 + i00] += v; dst_data[i10*s0 + i00] += v;
} }
} }
@ -12490,9 +12546,9 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
const int i1n = i10*ne11; const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
float v = 0; float v = 0;
ggml_vec_dot_f32(ne02, &v, ggml_vec_dot_f32(ne02, &v, 0,
wdata_src + i1n, wdata_src + i1n, 0,
wdata_kernel + i00*ne02); wdata_kernel + i00*ne02, 0, 1);
dst_data[i10*s0 + i00] += v; dst_data[i10*s0 + i00] += v;
} }
} }
@ -12807,9 +12863,9 @@ static void ggml_compute_forward_conv_transpose_2d(
for (int i01 = 0; i01 < ne01; i01++) { for (int i01 = 0; i01 < ne01; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
float v = 0; float v = 0;
ggml_vec_dot_f16(ne03, &v, ggml_vec_dot_f16(ne03, &v, 0,
wdata_src + i1n, wdata_src + i1n, 0,
wdata_kernel + i01*ne00*ne03 + i00*ne03); wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
} }
} }
@ -13238,9 +13294,9 @@ static void ggml_compute_forward_flash_attn_f32(
const int i1 = ik1; const int i1 = ik1;
ggml_vec_dot_f32(neq0, ggml_vec_dot_f32(neq0,
S + i1, S + i1, 0,
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
} }
// scale // scale
@ -13323,9 +13379,9 @@ static void ggml_compute_forward_flash_attn_f32(
const int iv3 = iq3; const int iv3 = iq3;
ggml_vec_dot_f32(masked_begin, ggml_vec_dot_f32(masked_begin,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
S); S, 0, 1);
} }
} }
} }
@ -13428,9 +13484,9 @@ static void ggml_compute_forward_flash_attn_f16(
const int i1 = ik1; const int i1 = ik1;
ggml_vec_dot_f16(neq0, ggml_vec_dot_f16(neq0,
S + i1, S + i1, 0,
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
} }
} else { } else {
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
@ -13532,9 +13588,9 @@ static void ggml_compute_forward_flash_attn_f16(
const int iv3 = iq3; const int iv3 = iq3;
ggml_vec_dot_f16(nev0, ggml_vec_dot_f16(nev0,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
S16); S16, 0, 1);
} }
} else { } else {
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
@ -13676,9 +13732,9 @@ static void ggml_compute_forward_flash_ff_f16(
const int i1 = ib01; const int i1 = ib01;
ggml_vec_dot_f16(nea0, ggml_vec_dot_f16(nea0,
S + i1, S + i1, 0,
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
} }
ggml_vec_add_f32(neb01, S, S, (float *) b1->data); ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
@ -13701,9 +13757,9 @@ static void ggml_compute_forward_flash_ff_f16(
for (int64_t ic = 0; ic < nec01; ++ic) { for (int64_t ic = 0; ic < nec01; ++ic) {
ggml_vec_dot_f16(neb01, ggml_vec_dot_f16(neb01,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
S16); S16, 0, 1);
} }
ggml_vec_add_f32(nec01, ggml_vec_add_f32(nec01,
@ -13890,9 +13946,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
const int i1 = ik1; const int i1 = ik1;
ggml_vec_dot_f32(neq0, ggml_vec_dot_f32(neq0,
S + i1, S + i1, 0,
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
} }
// scale // scale
@ -14037,7 +14093,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S = SM * (S - dot(SM, S)) // S = SM * (S - dot(SM, S))
float dot_SM_gradSM = 0; float dot_SM_gradSM = 0;
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
ggml_vec_mul_f32 (masked_begin, S, S, SM); ggml_vec_mul_f32 (masked_begin, S, S, SM);
@ -15335,7 +15391,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
return NULL; return NULL;
} }
if (node->is_param) { if (node->flags & GGML_TENSOR_FLAG_PARAM) {
return node; return node;
} }
@ -15369,7 +15425,7 @@ static struct ggml_tensor * ggml_recompute_graph_node(
clone->op = node->op; clone->op = node->op;
clone->grad = node->grad; clone->grad = node->grad;
clone->is_param = node->is_param; clone->flags = node->flags;
clone->extra = node->extra; clone->extra = node->extra;
for (int k = 0; k < GGML_MAX_DIMS; ++k) { for (int k = 0; k < GGML_MAX_DIMS; ++k) {
clone->nb[k] = node->nb[k]; clone->nb[k] = node->nb[k];
@ -16401,7 +16457,7 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
for (int i = 0; i < gf->n_nodes; i++) { for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i]; struct ggml_tensor * node = gf->nodes[i];
if (node->is_param) { if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
ggml_build_forward_expand(gb, node->grad); ggml_build_forward_expand(gb, node->grad);
} }
@ -16692,7 +16748,7 @@ struct ggml_compute_state_shared {
atomic_int node_n; // active graph node atomic_int node_n; // active graph node
atomic_int node_task; // active graph node task phase atomic_int node_task; // active graph node task phase
bool (*abort_callback)(void * data); // abort ggml_graph_compute when true ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
void * abort_callback_data; void * abort_callback_data;
}; };
@ -17905,7 +17961,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
i, i,
node->ne[0], node->ne[1], node->ne[2], node->ne[0], node->ne[1], node->ne[2],
ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs,
(double) node->perf_cycles / (double) ggml_cycles_per_ms(), (double) node->perf_cycles / (double) ggml_cycles_per_ms(),
(double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
(double) node->perf_time_us / 1000.0, (double) node->perf_time_us / 1000.0,
@ -17998,7 +18054,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
continue; continue;
} }
if (node->is_param) { if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow"); snprintf(color, sizeof(color), "yellow");
} else if (node->grad) { } else if (node->grad) {
if (ggml_graph_find(gf, node)) { if (ggml_graph_find(gf, node)) {
@ -18172,7 +18228,7 @@ static enum ggml_opt_result ggml_opt_adam(
int np = 0; int np = 0;
int64_t nx = 0; int64_t nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) { for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) { if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_ASSERT(np < GGML_MAX_PARAMS); GGML_ASSERT(np < GGML_MAX_PARAMS);
@ -18425,7 +18481,7 @@ static enum ggml_opt_result linesearch_backtracking(
} }
// compute the initial gradient in the search direction // compute the initial gradient in the search direction
ggml_vec_dot_f32(nx, &dginit, g, d); ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
// make sure that d points to a descent direction // make sure that d points to a descent direction
if (0 < dginit) { if (0 < dginit) {
@ -18475,7 +18531,7 @@ static enum ggml_opt_result linesearch_backtracking(
return count; return count;
} }
ggml_vec_dot_f32(nx, &dg, g, d); ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
// check the Wolfe condition // check the Wolfe condition
if (dg < params->lbfgs.wolfe * dginit) { if (dg < params->lbfgs.wolfe * dginit) {
@ -18535,7 +18591,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
int np = 0; int np = 0;
int nx = 0; int nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) { for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) { if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_ASSERT(np < GGML_MAX_PARAMS); GGML_ASSERT(np < GGML_MAX_PARAMS);
@ -18736,8 +18792,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
// ys = y^t \cdot s -> 1 / \rho. // ys = y^t \cdot s -> 1 / \rho.
// yy = y^t \cdot y. // yy = y^t \cdot y.
// //
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
lm_ys[end[0]] = ys; lm_ys[end[0]] = ys;
@ -18756,7 +18812,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
for (int i = 0; i < bound; ++i) { for (int i = 0; i < bound; ++i) {
j[0] = (j[0] + m - 1) % m; j[0] = (j[0] + m - 1) % m;
// \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
lm_alpha[j[0]] /= lm_ys[j[0]]; lm_alpha[j[0]] /= lm_ys[j[0]];
// q_{i} = q_{i+1} - \alpha_{i} y_{i} // q_{i} = q_{i+1} - \alpha_{i} y_{i}
ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
@ -18766,7 +18822,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
for (int i = 0; i < bound; ++i) { for (int i = 0; i < bound; ++i) {
// \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
beta /= lm_ys[j[0]]; beta /= lm_ys[j[0]];
// \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
@ -19010,6 +19066,16 @@ enum ggml_opt_result ggml_opt_resume_g(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void ggml_set_input(struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_INPUT;
}
void ggml_set_output(struct ggml_tensor * tensor) {
tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}
////////////////////////////////////////////////////////////////////////////////
void ggml_quantize_init(enum ggml_type type) { void ggml_quantize_init(enum ggml_type type) {
ggml_critical_section_start(); ggml_critical_section_start();
@ -20654,4 +20720,12 @@ int ggml_cpu_has_vsx(void) {
#endif #endif
} }
int ggml_cpu_has_matmul_int8(void) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
return 1;
#else
return 0;
#endif
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

32
ggml.h
View file

@ -505,11 +505,17 @@ extern "C" {
enum ggml_log_level { enum ggml_log_level {
GGML_LOG_LEVEL_ERROR = 2, GGML_LOG_LEVEL_ERROR = 2,
GGML_LOG_LEVEL_WARN = 3, GGML_LOG_LEVEL_WARN = 3,
GGML_LOG_LEVEL_INFO = 4, GGML_LOG_LEVEL_INFO = 4,
GGML_LOG_LEVEL_DEBUG = 5 GGML_LOG_LEVEL_DEBUG = 5
}; };
enum ggml_tensor_flag {
GGML_TENSOR_FLAG_INPUT = 1,
GGML_TENSOR_FLAG_OUTPUT = 2,
GGML_TENSOR_FLAG_PARAM = 4,
};
// ggml object // ggml object
struct ggml_object { struct ggml_object {
size_t offs; size_t offs;
@ -543,7 +549,7 @@ extern "C" {
// op params - allocated as int32_t for alignment // op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
bool is_param; int32_t flags;
struct ggml_tensor * grad; struct ggml_tensor * grad;
struct ggml_tensor * src[GGML_MAX_SRC]; struct ggml_tensor * src[GGML_MAX_SRC];
@ -567,6 +573,11 @@ extern "C" {
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*ggml_abort_callback)(void * data);
// the compute plan that needs to be prepared for ggml_graph_compute() // the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287 // since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan { struct ggml_cplan {
@ -576,8 +587,8 @@ extern "C" {
int n_threads; int n_threads;
// abort ggml_graph_compute when true // abort ggml_graph_compute when true
bool (*abort_callback)(void * data); ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
}; };
enum ggml_cgraph_eval_order { enum ggml_cgraph_eval_order {
@ -2097,6 +2108,12 @@ extern "C" {
ggml_opt_callback callback, ggml_opt_callback callback,
void * callback_data); void * callback_data);
//
// tensor flags
//
GGML_API void ggml_set_input(struct ggml_tensor * tensor);
GGML_API void ggml_set_output(struct ggml_tensor * tensor);
// //
// quantization // quantization
// //
@ -2283,6 +2300,7 @@ extern "C" {
GGML_API int ggml_cpu_has_ssse3 (void); GGML_API int ggml_cpu_has_ssse3 (void);
GGML_API int ggml_cpu_has_sycl (void); GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_vsx (void);
GGML_API int ggml_cpu_has_matmul_int8(void);
// //
// Internal types and functions exposed for tests and benchmarks // Internal types and functions exposed for tests and benchmarks
@ -2296,7 +2314,8 @@ extern "C" {
#endif #endif
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y); typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc);
typedef struct { typedef struct {
const char * type_name; const char * type_name;
@ -2308,6 +2327,7 @@ extern "C" {
ggml_from_float_t from_float_reference; ggml_from_float_t from_float_reference;
ggml_vec_dot_t vec_dot; ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type; enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously;
} ggml_type_traits_t; } ggml_type_traits_t;
GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

View file

@ -2067,6 +2067,8 @@ type_names = {
K_QUANTS_PER_ITERATION = 2 K_QUANTS_PER_ITERATION = 2
ASYNCIO_CONCURRENCY = 64
output_dir = gettempdir() output_dir = gettempdir()
lock = asyncio.Lock() lock = asyncio.Lock()
@ -2291,7 +2293,14 @@ async def main():
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
await asyncio.gather(*tasks) # Helper to decorate tasks with semaphore acquisition.
async def withSemaphore(sem, task):
async with sem:
return await task
# Run tasks concurrently guarded by a concurrency limit.
sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
with open("ggml-vulkan-shaders.hpp", "w") as f: with open("ggml-vulkan-shaders.hpp", "w") as f:
f.write("#include <cstdint>\n\n") f.write("#include <cstdint>\n\n")

View file

@ -50,6 +50,7 @@ class Keys:
VALUE_LENGTH = "{arch}.attention.value_length" VALUE_LENGTH = "{arch}.attention.value_length"
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
CAUSAL = "{arch}.attention.causal"
class Rope: class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_COUNT = "{arch}.rope.dimension_count"
@ -60,22 +61,23 @@ class Keys:
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
class Tokenizer: class Tokenizer:
MODEL = "tokenizer.ggml.model" MODEL = "tokenizer.ggml.model"
LIST = "tokenizer.ggml.tokens" LIST = "tokenizer.ggml.tokens"
TOKEN_TYPE = "tokenizer.ggml.token_type" TOKEN_TYPE = "tokenizer.ggml.token_type"
SCORES = "tokenizer.ggml.scores" TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types
MERGES = "tokenizer.ggml.merges" SCORES = "tokenizer.ggml.scores"
BOS_ID = "tokenizer.ggml.bos_token_id" MERGES = "tokenizer.ggml.merges"
EOS_ID = "tokenizer.ggml.eos_token_id" BOS_ID = "tokenizer.ggml.bos_token_id"
UNK_ID = "tokenizer.ggml.unknown_token_id" EOS_ID = "tokenizer.ggml.eos_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id" UNK_ID = "tokenizer.ggml.unknown_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id" SEP_ID = "tokenizer.ggml.seperator_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token" PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_EOS = "tokenizer.ggml.add_eos_token" ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix" ADD_EOS = "tokenizer.ggml.add_eos_token"
HF_JSON = "tokenizer.huggingface.json" ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
RWKV = "tokenizer.rwkv.world" HF_JSON = "tokenizer.huggingface.json"
CHAT_TEMPLATE = "tokenizer.chat_template" RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
# #
@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
ATTN_OUT = auto() ATTN_OUT = auto()
ATTN_NORM = auto() ATTN_NORM = auto()
ATTN_NORM_2 = auto() ATTN_NORM_2 = auto()
ATTN_OUT_NORM = auto()
ATTN_ROT_EMBD = auto() ATTN_ROT_EMBD = auto()
FFN_GATE_INP = auto() FFN_GATE_INP = auto()
FFN_NORM = auto() FFN_NORM = auto()
@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
FFN_UP_EXP = auto() FFN_UP_EXP = auto()
ATTN_Q_NORM = auto() ATTN_Q_NORM = auto()
ATTN_K_NORM = auto() ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -178,6 +182,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
@ -187,6 +192,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
} }
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -262,17 +268,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
], ],
MODEL_ARCH.BERT: [ MODEL_ARCH.BERT: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
], ],
MODEL_ARCH.MPT: [ MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View file

@ -357,6 +357,9 @@ class GGUFWriter:
def add_layer_norm_rms_eps(self, value: float) -> None: def add_layer_norm_rms_eps(self, value: float) -> None:
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
def add_causal_attention(self, value: bool) -> None:
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
def add_rope_dimension_count(self, count: int) -> None: def add_rope_dimension_count(self, count: int) -> None:
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
@ -387,6 +390,9 @@ class GGUFWriter:
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
def add_token_type_count(self, value: int) -> None:
self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
def add_token_scores(self, scores: Sequence[float]) -> None: def add_token_scores(self, scores: Sequence[float]) -> None:
self.add_array(Keys.Tokenizer.SCORES, scores) self.add_array(Keys.Tokenizer.SCORES, scores)

View file

@ -30,6 +30,7 @@ class TensorNameMap:
# Normalization of token embeddings # Normalization of token embeddings
MODEL_TENSOR.TOKEN_EMBD_NORM: ( MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom "word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
), ),
# Position embeddings # Position embeddings
@ -54,7 +55,6 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon "transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2 "model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth "norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt "transformer.norm_f", # mpt
"ln_f", # refact bloom qwen gpt2 "ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon "language_model.encoder.final_layernorm", # persimmon
@ -79,7 +79,6 @@ class TensorNameMap:
"transformer.h.{bid}.ln_mlp", # falcon40b "transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf "model.layers.{bid}.input_layernorm", # llama-hf
"layers.{bid}.attention_norm", # llama-pth "layers.{bid}.attention_norm", # llama-pth
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi "model.layers.{bid}.ln1", # yi
"h.{bid}.ln_1", # gpt2 "h.{bid}.ln_1", # gpt2
@ -155,6 +154,11 @@ class TensorNameMap:
"model.layers.{bid}.attention.wo", # internlm2 "model.layers.{bid}.attention.wo", # internlm2
), ),
# Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
),
# Rotary embeddings # Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: ( MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
@ -171,7 +175,6 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_2", # mpt "transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf "model.layers.{bid}.post_attention_layernorm", # llama-hf
"layers.{bid}.ffn_norm", # llama-pth "layers.{bid}.ffn_norm", # llama-pth
"encoder.layer.{bid}.output.LayerNorm", # bert
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi "model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2 "h.{bid}.ln_2", # gpt2
@ -266,6 +269,10 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: ( MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
), ),
MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm", # bert
)
} }
mapping: dict[str, tuple[MODEL_TENSOR, str]] mapping: dict[str, tuple[MODEL_TENSOR, str]]

694
llama.cpp

File diff suppressed because it is too large Load diff

View file

@ -61,6 +61,7 @@ extern "C" {
enum llama_vocab_type { enum llama_vocab_type {
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
}; };
enum llama_token_type { enum llama_token_type {

View file

@ -156,8 +156,8 @@ int main(int argc, char** argv) {
t1 = std::chrono::high_resolution_clock::now(); t1 = std::chrono::high_resolution_clock::now();
float fs; float fs;
if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, x40.data(), y.data()); if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1);
else funcs.vec_dot(kVecSize * QK4_1, &fs, x41.data(), y.data()); else funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1);
t2 = std::chrono::high_resolution_clock::now(); t2 = std::chrono::high_resolution_clock::now();
t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count(); t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
if (iloop > 3) ggml.addResult(fs, t); if (iloop > 3) ggml.addResult(fs, t);

View file

@ -284,8 +284,8 @@ int main(int argc, char** argv) {
else { else {
auto vdot = ggml_internal_get_type_traits(funcs.vec_dot_type); auto vdot = ggml_internal_get_type_traits(funcs.vec_dot_type);
vdot.from_float(y1.data(), q8.data(), kVecSize); vdot.from_float(y1.data(), q8.data(), kVecSize);
if (useQ4_1) funcs.vec_dot(kVecSize, &result, q41.data(), q8.data()); if (useQ4_1) funcs.vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1);
else funcs.vec_dot(kVecSize, &result, q40.data(), q8.data()); else funcs.vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1);
} }
sumq += result; sumq += result;
t2 = std::chrono::high_resolution_clock::now(); t2 = std::chrono::high_resolution_clock::now();

View file

@ -97,6 +97,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-cuda.cu -> ggml-cuda.cu # src/ggml-cuda.cu -> ggml-cuda.cu
# src/ggml-cuda.h -> ggml-cuda.h # src/ggml-cuda.h -> ggml-cuda.h
# src/ggml-impl.h -> ggml-impl.h # src/ggml-impl.h -> ggml-impl.h
# src/ggml-kompute.cpp -> ggml-kompute.cpp
# src/ggml-kompute.h -> ggml-kompute.h
# src/ggml-metal.h -> ggml-metal.h # src/ggml-metal.h -> ggml-metal.h
# src/ggml-metal.m -> ggml-metal.m # src/ggml-metal.m -> ggml-metal.m
# src/ggml-mpi.h -> ggml-mpi.h # src/ggml-mpi.h -> ggml-mpi.h
@ -105,6 +107,10 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-opencl.h -> ggml-opencl.h # src/ggml-opencl.h -> ggml-opencl.h
# src/ggml-quants.c -> ggml-quants.c # src/ggml-quants.c -> ggml-quants.c
# src/ggml-quants.h -> ggml-quants.h # src/ggml-quants.h -> ggml-quants.h
# src/ggml-sycl.cpp -> ggml-sycl.cpp
# src/ggml-sycl.h -> ggml-sycl.h
# src/ggml-vulkan.cpp -> ggml-vulkan.cpp
# src/ggml-vulkan.h -> ggml-vulkan.h
# include/ggml/ggml.h -> ggml.h # include/ggml/ggml.h -> ggml.h
# include/ggml/ggml-alloc.h -> ggml-alloc.h # include/ggml/ggml-alloc.h -> ggml-alloc.h
# include/ggml/ggml-backend.h -> ggml-backend.h # include/ggml/ggml-backend.h -> ggml-backend.h
@ -123,6 +129,8 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/src\/ggml-cuda\.cu/ggml-cuda.cu/g' \ -e 's/src\/ggml-cuda\.cu/ggml-cuda.cu/g' \
-e 's/src\/ggml-cuda\.h/ggml-cuda.h/g' \ -e 's/src\/ggml-cuda\.h/ggml-cuda.h/g' \
-e 's/src\/ggml-impl\.h/ggml-impl.h/g' \ -e 's/src\/ggml-impl\.h/ggml-impl.h/g' \
-e 's/src\/ggml-kompute\.cpp/ggml-kompute.cpp/g' \
-e 's/src\/ggml-kompute\.h/ggml-kompute.h/g' \
-e 's/src\/ggml-metal\.h/ggml-metal.h/g' \ -e 's/src\/ggml-metal\.h/ggml-metal.h/g' \
-e 's/src\/ggml-metal\.m/ggml-metal.m/g' \ -e 's/src\/ggml-metal\.m/ggml-metal.m/g' \
-e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \ -e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \
@ -131,6 +139,10 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \ -e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \
-e 's/src\/ggml-quants\.c/ggml-quants.c/g' \ -e 's/src\/ggml-quants\.c/ggml-quants.c/g' \
-e 's/src\/ggml-quants\.h/ggml-quants.h/g' \ -e 's/src\/ggml-quants\.h/ggml-quants.h/g' \
-e 's/src\/ggml-sycl\.cpp/ggml-sycl.cpp/g' \
-e 's/src\/ggml-sycl\.h/ggml-sycl.h/g' \
-e 's/src\/ggml-vulkan\.cpp/ggml-vulkan.cpp/g' \
-e 's/src\/ggml-vulkan\.h/ggml-vulkan.h/g' \
-e 's/include\/ggml\/ggml\.h/ggml.h/g' \ -e 's/include\/ggml\/ggml\.h/ggml.h/g' \
-e 's/include\/ggml\/ggml-alloc\.h/ggml-alloc.h/g' \ -e 's/include\/ggml\/ggml-alloc\.h/ggml-alloc.h/g' \
-e 's/include\/ggml\/ggml-backend\.h/ggml-backend.h/g' \ -e 's/include\/ggml\/ggml-backend\.h/ggml-backend.h/g' \

View file

@ -1 +1 @@
475cbad5c1c834e31e26a2283bc1413181644360 5070f078a67c18c11736e78316ab715ca9afde16

View file

@ -7,6 +7,8 @@ cp -rpv ../ggml/src/ggml-backend.c ./ggml-backend.c
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-impl.h ./ggml-impl.h cp -rpv ../ggml/src/ggml-impl.h ./ggml-impl.h
cp -rpv ../ggml/src/ggml-kompute.cpp ./ggml-kompute.cpp
cp -rpv ../ggml/src/ggml-kompute.h ./ggml-kompute.h
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
@ -16,6 +18,10 @@ cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c
cp -rpv ../ggml/src/ggml-quants.h ./ggml-quants.h cp -rpv ../ggml/src/ggml-quants.h ./ggml-quants.h
cp -rpv ../ggml/src/ggml-sycl.cpp ./ggml-sycl.cpp
cp -rpv ../ggml/src/ggml-sycl.h ./ggml-sycl.h
cp -rpv ../ggml/src/ggml-vulkan.cpp ./ggml-vulkan.cpp
cp -rpv ../ggml/src/ggml-vulkan.h ./ggml-vulkan.h
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
cp -rpv ../ggml/include/ggml/ggml-backend.h ./ggml-backend.h cp -rpv ../ggml/include/ggml/ggml-backend.h ./ggml-backend.h

1
spm-headers/ggml-alloc.h Symbolic link
View file

@ -0,0 +1 @@
../ggml-alloc.h

1
spm-headers/ggml-backend.h Symbolic link
View file

@ -0,0 +1 @@
../ggml-backend.h

1
spm-headers/ggml.h Symbolic link
View file

@ -0,0 +1 @@
../ggml.h

View file

@ -87,7 +87,7 @@ static float dot_product_error(
vdot.from_float(test_data2, tmp_q2.data(), test_size); vdot.from_float(test_data2, tmp_q2.data(), test_size);
float result = INFINITY; float result = INFINITY;
qfns.vec_dot(test_size, &result, tmp_q1.data(), tmp_q2.data()); qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
const float dot_ref = dot_product(test_data1, test_data2, test_size); const float dot_ref = dot_product(test_data1, test_data2, test_size);

View file

@ -346,7 +346,7 @@ int main(int argc, char * argv[]) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024)); printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float { auto quantize_fn = [&](void) -> float {
float result; float result;
qfns.vec_dot(size, &result, test_q1, test_q2); qfns.vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
return result; return result;
}; };
size_t quantized_size = ggml_row_size(type, size); size_t quantized_size = ggml_row_size(type, size);